fix validators

fixed random_apply tests

fix engine validation
nhussain 2020-06-19 09:58:38 -04:00
parent 915ddd25dd
commit 6c37ea3be0
36 changed files with 1136 additions and 1759 deletions


@@ -0,0 +1,342 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
General Validators.
"""
import inspect
from multiprocessing import cpu_count
import os
import numpy as np
from ..engine import samplers
# POS_INT_MIN is used to limit values to a minimum of 1 (i.e. to exclude 0)
POS_INT_MIN = 1
UINT8_MAX = 255
UINT8_MIN = 0
UINT32_MAX = 4294967295
UINT32_MIN = 0
UINT64_MAX = 18446744073709551615
UINT64_MIN = 0
INT32_MAX = 2147483647
INT32_MIN = -2147483648
INT64_MAX = 9223372036854775807
INT64_MIN = -9223372036854775808
FLOAT_MAX_INTEGER = 16777216
FLOAT_MIN_INTEGER = -16777216
DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992
valid_detype = [
"bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64", "float16", "float32", "float64", "string"
]
def pad_arg_name(arg_name):
if arg_name != "":
arg_name = arg_name + " "
return arg_name
def check_value(value, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value < valid_range[0] or value > valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
valid_range[1]))
def check_range(values, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
valid_range[1]))
def check_positive(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value <= 0:
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
def check_positive_float(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
type_check(value, (float,), arg_name)
check_positive(value, arg_name)
def check_2tuple(value, arg_name=""):
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
def check_uint8(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT8_MIN, UINT8_MAX])
def check_uint32(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT32_MIN, UINT32_MAX])
def check_pos_int32(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [POS_INT_MIN, INT32_MAX])
def check_uint64(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT64_MIN, UINT64_MAX])
def check_pos_int64(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT64_MIN, INT64_MAX])
def check_pos_float32(value, arg_name=""):
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)
def check_pos_float64(value, arg_name=""):
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name)
def check_valid_detype(type_):
if type_ not in valid_detype:
raise ValueError("Unknown column type")
return True
def check_columns(columns, name):
type_check(columns, (list, str), name)
if isinstance(columns, list):
if not columns:
raise ValueError("Column names should not be empty")
col_names = ["col_{0}".format(i) for i in range(len(columns))]
type_check_list(columns, (str,), col_names)
def parse_user_args(method, *args, **kwargs):
"""
Parse user arguments in a function
Args:
method (method): a callable function
*args: user passed args
**kwargs: user passed kwargs
Returns:
user_filled_args (list): values of what the user passed in for the arguments,
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
"""
sig = inspect.signature(method)
if 'self' in sig.parameters or 'cls' in sig.parameters:
ba = sig.bind(method, *args, **kwargs)
ba.apply_defaults()
params = list(sig.parameters.keys())[1:]
else:
ba = sig.bind(*args, **kwargs)
ba.apply_defaults()
params = list(sig.parameters.keys())
user_filled_args = [ba.arguments.get(arg_value) for arg_value in params]
return user_filled_args, ba.arguments
def type_check_list(args, types, arg_names):
"""
Check the type of each parameter in the list
Args:
args (list, tuple): a list or tuple of any variable
types (tuple): tuple of all valid types for arg
arg_names (list, tuple of str): the names of args
Returns:
Exception: when the type is not correct, otherwise nothing
"""
type_check(args, (list, tuple,), arg_names)
if len(args) != len(arg_names):
raise ValueError("List of arguments is not the same length as argument_names.")
for arg, arg_name in zip(args, arg_names):
type_check(arg, types, arg_name)
def type_check(arg, types, arg_name):
"""
Check the type of the parameter
Args:
arg : any variable
types (tuple): tuple of all valid types for arg
arg_name (str): the name of arg
Returns:
Exception: when the type is not correct, otherwise nothing
"""
# handle special case of booleans being a subclass of ints
print_value = '\"\"' if repr(arg) == repr('') else arg
if int in types and bool not in types:
if isinstance(arg, bool):
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
if not isinstance(arg, types):
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
def check_filename(path):
"""
check the filename in the path
Args:
path (str): the path
Returns:
Exception: when error
"""
if not isinstance(path, str):
raise TypeError("path: {} is not string".format(path))
filename = os.path.basename(path)
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
# '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
# '*', '(', '%', ')', '-', '=', '{', '?', '$'
forbidden_symbols = set(r'\/:*?"<>|`&\';')
if set(filename) & forbidden_symbols:
raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
if filename.startswith(' ') or filename.endswith(' '):
raise ValueError("filename should not start/end with space")
return True
def check_dir(dataset_dir):
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
def check_file(dataset_file):
check_filename(dataset_file)
if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
def check_sampler_shuffle_shard_options(param_dict):
"""
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args:
param_dict (dict): param_dict
Returns:
Exception: ValueError or RuntimeError if error
"""
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
if sampler is not None:
if shuffle is not None:
raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
if num_shards is not None:
check_pos_int32(num_shards)
if shard_id is None:
raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
check_value(shard_id, [0, num_shards - 1], "shard_id")
if num_shards is None and shard_id is not None:
raise RuntimeError("shard_id is specified but num_shards is not.")
def check_padding_options(param_dict):
"""
Check for valid padded_sample and num_padded of padded samples
Args:
param_dict (dict): param_dict
Returns:
Exception: ValueError or RuntimeError if error
"""
columns_list = param_dict.get('columns_list')
block_reader = param_dict.get('block_reader')
padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
if padded_sample is not None:
if num_padded is None:
raise RuntimeError("padded_sample is specified and requires num_padded as well.")
if num_padded < 0:
raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
if columns_list is None:
raise RuntimeError("padded_sample is specified and requires columns_list as well.")
for column in columns_list:
if column not in padded_sample:
raise ValueError("padded_sample cannot match columns_list.")
if block_reader:
raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.")
if padded_sample is None and num_padded is not None:
raise RuntimeError("num_padded is specified but padded_sample is not.")
def check_num_parallel_workers(value):
type_check(value, (int,), "num_parallel_workers")
if value < 1 or value > cpu_count():
raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
def check_num_samples(value):
type_check(value, (int,), "num_samples")
check_value(value, [0, INT32_MAX], "num_samples")
def validate_dataset_param_value(param_list, param_dict, param_type):
for param_name in param_list:
if param_dict.get(param_name) is not None:
if param_name == 'num_parallel_workers':
check_num_parallel_workers(param_dict.get(param_name))
if param_name == 'num_samples':
check_num_samples(param_dict.get(param_name))
else:
type_check(param_dict.get(param_name), (param_type,), param_name)
def check_gnn_list_or_ndarray(param, param_name):
"""
Check if the input parameter is list or numpy.ndarray.
Args:
param (list, nd.ndarray): param
param_name (str): param_name
Returns:
Exception: TypeError if error
"""
type_check(param, (list, np.ndarray), param_name)
if isinstance(param, list):
param_names = ["param_{0}".format(i) for i in range(len(param))]
type_check_list(param, (int,), param_names)
elif isinstance(param, np.ndarray):
if not param.dtype == np.int32:
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
param_name, param.dtype))
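
To show how these helpers are meant to be composed, here is a minimal sketch of a validator decorator written in the style used throughout this commit. The op name and its parameters (size, prob) are hypothetical; only the helper calls are taken from the file above.

from functools import wraps

def check_some_op(method):
    """Hypothetical validator: size must be a positive int32, prob a number in [0, 1]."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(size, (int,), "size")
        check_value(size, [POS_INT_MIN, INT32_MAX], "size")
        type_check(prob, (int, float), "prob")
        check_value(prob, [0.0, 1.0], "prob")
        return method(self, *args, **kwargs)
    return new_method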

File diff suppressed because it is too large


@@ -98,7 +98,7 @@ class Ngram(cde.NgramOp):
"""
@check_ngram
-def __init__(self, n, left_pad=None, right_pad=None, separator=None):
+def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0],
r_pad_token=right_pad[0], separator=separator)
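
With the padding and separator defaults moved into the signature, the common case needs no extra arguments. Illustrative usage only, assuming the usual text transforms import path:

import mindspore.dataset.text as text

# equivalent to text.Ngram(3, left_pad=("", 0), right_pad=("", 0), separator=" ")
ngram_op = text.Ngram(3)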


@@ -28,6 +28,7 @@ __all__ = [
"Vocab", "to_str", "to_bytes"
]
class Vocab(cde.Vocab):
"""
Vocab object that is used to lookup a word.
@@ -38,7 +39,7 @@ class Vocab(cde.Vocab):
@classmethod
@check_from_dataset
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
-special_first=None):
+special_first=True):
"""
Build a vocab from a dataset.
@@ -62,13 +63,21 @@ class Vocab(cde.Vocab):
special_tokens(list, optional): a list of strings, each one is a special token. for example
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
-is specified and special_first is set to None, special_tokens will be prepended (default=None).
+is specified and special_first is set to True, special_tokens will be prepended (default=True).
Returns:
Vocab, Vocab object built from dataset.
"""
vocab = Vocab()
+if columns is None:
+columns = []
+if not isinstance(columns, list):
+columns = [columns]
+if freq_range is None:
+freq_range = (None, None)
+if special_tokens is None:
+special_tokens = []
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
for d in root.create_dict_iterator():
if d is not None:
@@ -77,7 +86,7 @@ class Vocab(cde.Vocab):
@classmethod
@check_from_list
-def from_list(cls, word_list, special_tokens=None, special_first=None):
+def from_list(cls, word_list, special_tokens=None, special_first=True):
"""
Build a vocab object from a list of word.
@@ -86,29 +95,33 @@ class Vocab(cde.Vocab):
special_tokens(list, optional): a list of strings, each one is a special token. for example
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
-is specified and special_first is set to None, special_tokens will be prepended (default=None).
+is specified and special_first is set to True, special_tokens will be prepended (default=True).
"""
+if special_tokens is None:
+special_tokens = []
return super().from_list(word_list, special_tokens, special_first)
@classmethod
@check_from_file
-def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
+def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True):
"""
Build a vocab object from a list of word.
Args:
file_path (str): path to the file which contains the vocab list.
delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be
-the word (default=None).
+the word (default="").
vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken).
special_tokens (list, optional): a list of strings, each one is a special token. for example
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
-If special_tokens is specified and special_first is set to None,
+If special_tokens is specified and special_first is set to True,
-special_tokens will be prepended (default=None).
+special_tokens will be prepended (default=True).
"""
+if vocab_size is None:
+vocab_size = -1
+if special_tokens is None:
+special_tokens = []
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
@classmethod
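
A minimal usage sketch of the updated defaults (the word list is illustrative, and the import path is assumed to be the mindspore.dataset.text package):

import mindspore.dataset.text as text

# special_first now defaults to True, so the special tokens are prepended
# to the vocabulary ahead of the regular words.
vocab = text.Vocab.from_list(["home", "world"], special_tokens=["<pad>", "<unk>"])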


@ -17,23 +17,22 @@ validators for text ops
""" """
from functools import wraps from functools import wraps
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore._c_dataengine as cde
from mindspore._c_expression import typing from mindspore._c_expression import typing
from ..transforms.validators import check_uint32, check_pos_int64
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \
INT32_MAX, check_value
def check_unique_list_of_words(words, arg_name): def check_unique_list_of_words(words, arg_name):
"""Check that words is a list and each element is a str without any duplication""" """Check that words is a list and each element is a str without any duplication"""
if not isinstance(words, list): type_check(words, (list,), arg_name)
raise ValueError(arg_name + " needs to be a list of words of type string.")
words_set = set() words_set = set()
for word in words: for word in words:
if not isinstance(word, str): type_check(word, (str,), arg_name)
raise ValueError("each word in " + arg_name + " needs to be type str.")
if word in words_set: if word in words_set:
raise ValueError(arg_name + " contains duplicate word: " + word + ".") raise ValueError(arg_name + " contains duplicate word: " + word + ".")
words_set.add(word) words_set.add(word)
@ -45,21 +44,14 @@ def check_lookup(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
vocab, unknown = (list(args) + 2 * [None])[:2] [vocab, unknown], _ = parse_user_args(method, *args, **kwargs)
if "vocab" in kwargs:
vocab = kwargs.get("vocab")
if "unknown" in kwargs:
unknown = kwargs.get("unknown")
if unknown is not None: if unknown is not None:
if not (isinstance(unknown, int) and unknown >= 0): type_check(unknown, (int,), "unknown")
raise ValueError("unknown needs to be a non-negative integer.") check_positive(unknown)
type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
if not isinstance(vocab, cde.Vocab): return method(self, *args, **kwargs)
raise ValueError("vocab is not an instance of cde.Vocab.")
kwargs["vocab"] = vocab
kwargs["unknown"] = unknown
return method(self, **kwargs)
return new_method return new_method
@ -69,50 +61,15 @@ def check_from_file(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5] [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
if "file_path" in kwargs: **kwargs)
file_path = kwargs.get("file_path")
if "delimiter" in kwargs:
delimiter = kwargs.get("delimiter")
if "vocab_size" in kwargs:
vocab_size = kwargs.get("vocab_size")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if not isinstance(file_path, str):
raise ValueError("file_path needs to be str.")
if delimiter is not None:
if not isinstance(delimiter, str):
raise ValueError("delimiter needs to be str.")
else:
delimiter = ""
if vocab_size is not None:
if not (isinstance(vocab_size, int) and vocab_size > 0):
raise ValueError("vocab size needs to be a positive integer.")
else:
vocab_size = -1
if special_first is None:
special_first = True
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value")
if special_tokens is None:
special_tokens = []
check_unique_list_of_words(special_tokens, "special_tokens") check_unique_list_of_words(special_tokens, "special_tokens")
type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
if vocab_size is not None:
check_value(vocab_size, (-1, INT32_MAX), "vocab_size")
type_check(special_first, (bool,), special_first)
kwargs["file_path"] = file_path return method(self, *args, **kwargs)
kwargs["delimiter"] = delimiter
kwargs["vocab_size"] = vocab_size
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return new_method return new_method
@ -122,33 +79,20 @@ def check_from_list(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3] [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
if "word_list" in kwargs:
word_list = kwargs.get("word_list")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if special_tokens is None:
special_tokens = []
word_set = check_unique_list_of_words(word_list, "word_list") word_set = check_unique_list_of_words(word_list, "word_list")
token_set = check_unique_list_of_words(special_tokens, "special_tokens") if special_tokens is not None:
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
intersect = word_set.intersection(token_set) intersect = word_set.intersection(token_set)
if intersect != set(): if intersect != set():
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
if special_first is None: type_check(special_first, (bool,), "special_first")
special_first = True
if not isinstance(special_first, bool): return method(self, *args, **kwargs)
raise ValueError("special_first needs to be a boolean value.")
kwargs["word_list"] = word_list
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return new_method return new_method
@ -158,18 +102,15 @@ def check_from_dict(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
word_dict, = (list(args) + [None])[:1] [word_dict], _ = parse_user_args(method, *args, **kwargs)
if "word_dict" in kwargs:
word_dict = kwargs.get("word_dict") type_check(word_dict, (dict,), "word_dict")
if not isinstance(word_dict, dict):
raise ValueError("word_dict needs to be a list of word,id pairs.")
for word, word_id in word_dict.items(): for word, word_id in word_dict.items():
if not isinstance(word, str): type_check(word, (str,), "word")
raise ValueError("Each word in word_dict needs to be type string.") type_check(word_id, (int,), "word_id")
if not (isinstance(word_id, int) and word_id >= 0): check_value(word_id, (-1, INT32_MAX), "word_id")
raise ValueError("Each word id needs to be positive integer.") return method(self, *args, **kwargs)
kwargs["word_dict"] = word_dict
return method(self, **kwargs)
return new_method return new_method
@ -179,23 +120,8 @@ def check_jieba_init(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] parse_user_args(method, *args, **kwargs)
return method(self, *args, **kwargs)
if "hmm_path" in kwargs:
hmm_path = kwargs.get("hmm_path")
if "mp_path" in kwargs:
mp_path = kwargs.get("mp_path")
if hmm_path is None:
raise ValueError(
"The dict of HMMSegment in cppjieba is not provided.")
kwargs["hmm_path"] = hmm_path
if mp_path is None:
raise ValueError(
"The dict of MPSegment in cppjieba is not provided.")
kwargs["mp_path"] = mp_path
if model is not None:
kwargs["model"] = model
return method(self, **kwargs)
return new_method return new_method
@ -205,19 +131,12 @@ def check_jieba_add_word(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
word, freq = (list(args) + 2 * [None])[:2] [word, freq], _ = parse_user_args(method, *args, **kwargs)
if "word" in kwargs:
word = kwargs.get("word")
if "freq" in kwargs:
freq = kwargs.get("freq")
if word is None: if word is None:
raise ValueError("word is not provided.") raise ValueError("word is not provided.")
kwargs["word"] = word
if freq is not None: if freq is not None:
check_uint32(freq) check_uint32(freq)
kwargs["freq"] = freq return method(self, *args, **kwargs)
return method(self, **kwargs)
return new_method return new_method
@ -227,13 +146,8 @@ def check_jieba_add_dict(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
user_dict = (list(args) + [None])[0] parse_user_args(method, *args, **kwargs)
if "user_dict" in kwargs: return method(self, *args, **kwargs)
user_dict = kwargs.get("user_dict")
if user_dict is None:
raise ValueError("user_dict is not provided.")
kwargs["user_dict"] = user_dict
return method(self, **kwargs)
return new_method return new_method
@ -244,69 +158,39 @@ def check_from_dataset(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6] [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
if "dataset" in kwargs: **kwargs)
dataset = kwargs.get("dataset") if columns is not None:
if "columns" in kwargs: if not isinstance(columns, list):
columns = kwargs.get("columns") columns = [columns]
if "freq_range" in kwargs: col_names = ["col_{0}".format(i) for i in range(len(columns))]
freq_range = kwargs.get("freq_range") type_check_list(columns, (str,), col_names)
if "top_k" in kwargs:
top_k = kwargs.get("top_k")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if columns is None: if freq_range is not None:
columns = [] type_check(freq_range, (tuple,), "freq_range")
if not isinstance(columns, list): if len(freq_range) != 2:
columns = [columns] raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.")
for column in columns: for num in freq_range:
if not isinstance(column, str): if num is not None and (not isinstance(num, int)):
raise ValueError("columns need to be a list of strings.") raise ValueError(
"freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
if freq_range is None: if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
freq_range = (None, None) if freq_range[0] > freq_range[1] or freq_range[0] < 0:
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
if not isinstance(freq_range, tuple) or len(freq_range) != 2: type_check(top_k, (int, type(None)), "top_k")
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
for num in freq_range: if isinstance(top_k, int):
if num is not None and (not isinstance(num, int)): check_value(top_k, (0, INT32_MAX), "top_k")
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") type_check(special_first, (bool,), "special_first")
if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): if special_tokens is not None:
if freq_range[0] > freq_range[1] or freq_range[0] < 0: check_unique_list_of_words(special_tokens, "special_tokens")
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
if top_k is not None and (not isinstance(top_k, int)): return method(self, *args, **kwargs)
raise ValueError("top_k needs to be a positive integer.")
if isinstance(top_k, int) and top_k <= 0:
raise ValueError("top_k needs to be a positive integer.")
if special_first is None:
special_first = True
if special_tokens is None:
special_tokens = []
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value.")
check_unique_list_of_words(special_tokens, "special_tokens")
kwargs["dataset"] = dataset
kwargs["columns"] = columns
kwargs["freq_range"] = freq_range
kwargs["top_k"] = top_k
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return new_method return new_method
@ -316,15 +200,7 @@ def check_ngram(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4] [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)
if "n" in kwargs:
n = kwargs.get("n")
if "left_pad" in kwargs:
left_pad = kwargs.get("left_pad")
if "right_pad" in kwargs:
right_pad = kwargs.get("right_pad")
if "separator" in kwargs:
separator = kwargs.get("separator")
if isinstance(n, int): if isinstance(n, int):
n = [n] n = [n]
@ -332,15 +208,9 @@ def check_ngram(method):
if not (isinstance(n, list) and n != []): if not (isinstance(n, list) and n != []):
raise ValueError("n needs to be a non-empty list of positive integers.") raise ValueError("n needs to be a non-empty list of positive integers.")
for gram in n: for i, gram in enumerate(n):
if not (isinstance(gram, int) and gram > 0): type_check(gram, (int,), "gram[{0}]".format(i))
raise ValueError("n in ngram needs to be a positive number.") check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
if left_pad is None:
left_pad = ("", 0)
if right_pad is None:
right_pad = ("", 0)
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
left_pad[1], int)): left_pad[1], int)):
@ -353,11 +223,7 @@ def check_ngram(method):
if not (left_pad[1] >= 0 and right_pad[1] >= 0): if not (left_pad[1] >= 0 and right_pad[1] >= 0):
raise ValueError("padding width need to be positive numbers.") raise ValueError("padding width need to be positive numbers.")
if separator is None: type_check(separator, (str,), "separator")
separator = " "
if not isinstance(separator, str):
raise ValueError("separator needs to be a string.")
kwargs["n"] = n kwargs["n"] = n
kwargs["left_pad"] = left_pad kwargs["left_pad"] = left_pad
@ -374,16 +240,8 @@ def check_pair_truncate(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
max_length = (list(args) + [None])[0] parse_user_args(method, *args, **kwargs)
if "max_length" in kwargs: return method(self, *args, **kwargs)
max_length = kwargs.get("max_length")
if max_length is None:
raise ValueError("max_length is not provided.")
check_pos_int64(max_length)
kwargs["max_length"] = max_length
return method(self, **kwargs)
return new_method return new_method
@ -393,22 +251,13 @@ def check_to_number(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
data_type = (list(args) + [None])[0] [data_type], _ = parse_user_args(method, *args, **kwargs)
if "data_type" in kwargs: type_check(data_type, (typing.Type,), "data_type")
data_type = kwargs.get("data_type")
if data_type is None:
raise ValueError("data_type is a mandatory parameter but was not provided.")
if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.")
if data_type not in mstype.number_type: if data_type not in mstype.number_type:
raise TypeError("data_type is not numeric data type.") raise TypeError("data_type is not numeric data type.")
kwargs["data_type"] = data_type return method(self, *args, **kwargs)
return method(self, **kwargs)
return new_method return new_method
@ -418,18 +267,11 @@ def check_python_tokenizer(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
tokenizer = (list(args) + [None])[0] [tokenizer], _ = parse_user_args(method, *args, **kwargs)
if "tokenizer" in kwargs:
tokenizer = kwargs.get("tokenizer")
if tokenizer is None:
raise ValueError("tokenizer is a mandatory parameter.")
if not callable(tokenizer): if not callable(tokenizer):
raise TypeError("tokenizer is not a callable python function") raise TypeError("tokenizer is not a callable python function")
kwargs["tokenizer"] = tokenizer return method(self, *args, **kwargs)
return method(self, **kwargs)
return new_method return new_method


@ -18,6 +18,7 @@ from functools import wraps
import numpy as np import numpy as np
from mindspore._c_expression import typing from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive
# POS_INT_MIN is used to limit values from starting from 0 # POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1 POS_INT_MIN = 1
@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992 DOUBLE_MIN_INTEGER = -9007199254740992
def check_type(value, valid_type): def check_fill_value(method):
if not isinstance(value, valid_type): """Wrapper method to check the parameters of fill_value."""
raise ValueError("Wrong input type")
def check_value(value, valid_range):
if value < valid_range[0] or value > valid_range[1]:
raise ValueError("Input is not within the required range")
def check_range(values, valid_range):
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
raise ValueError("Input range is not valid")
def check_positive(value):
if value <= 0:
raise ValueError("Input must greater than 0")
def check_positive_float(value, valid_max=None):
if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max):
raise ValueError("Input need to be a valid positive float.")
def check_bool(value):
if not isinstance(value, bool):
raise ValueError("Value needs to be a boolean.")
def check_2tuple(value):
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Value needs to be a 2-tuple.")
def check_list(value):
if not isinstance(value, list):
raise ValueError("The input needs to be a list.")
def check_uint8(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT8_MIN, UINT8_MAX])
def check_uint32(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT32_MIN, UINT32_MAX])
def check_pos_int32(value):
"""Checks for int values starting from 1"""
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [POS_INT_MIN, INT32_MAX])
def check_uint64(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT64_MIN, UINT64_MAX])
def check_pos_int64(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT64_MIN, INT64_MAX])
def check_pos_float32(value):
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER])
def check_pos_float64(value):
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER])
def check_one_hot_op(method):
"""Wrapper method to check the parameters of one hot op."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
args = (list(args) + 2 * [None])[:2] [fill_value], _ = parse_user_args(method, *args, **kwargs)
num_classes, smoothing_rate = args type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
if "num_classes" in kwargs:
num_classes = kwargs.get("num_classes") return method(self, *args, **kwargs)
if "smoothing_rate" in kwargs:
smoothing_rate = kwargs.get("smoothing_rate") return new_method
def check_one_hot_op(method):
"""Wrapper method to check the parameters of one_hot_op."""
@wraps(method)
def new_method(self, *args, **kwargs):
[num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
type_check(num_classes, (int,), "num_classes")
check_positive(num_classes)
if num_classes is None:
raise ValueError("num_classes")
check_pos_int32(num_classes)
kwargs["num_classes"] = num_classes
if smoothing_rate is not None: if smoothing_rate is not None:
check_value(smoothing_rate, [0., 1.]) check_value(smoothing_rate, [0., 1.], "smoothing_rate")
kwargs["smoothing_rate"] = smoothing_rate
return method(self, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method
@ -146,35 +74,12 @@ def check_num_classes(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
num_classes = (list(args) + [None])[0] [num_classes], _ = parse_user_args(method, *args, **kwargs)
if "num_classes" in kwargs:
num_classes = kwargs.get("num_classes")
if num_classes is None:
raise ValueError("num_classes is not provided.")
check_pos_int32(num_classes) type_check(num_classes, (int,), "num_classes")
kwargs["num_classes"] = num_classes check_positive(num_classes)
return method(self, **kwargs) return method(self, *args, **kwargs)
return new_method
def check_fill_value(method):
"""Wrapper method to check the parameters of fill value."""
@wraps(method)
def new_method(self, *args, **kwargs):
fill_value = (list(args) + [None])[0]
if "fill_value" in kwargs:
fill_value = kwargs.get("fill_value")
if fill_value is None:
raise ValueError("fill_value is not provided.")
if not isinstance(fill_value, (str, float, bool, int, bytes)):
raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
kwargs["fill_value"] = fill_value
return method(self, **kwargs)
return new_method return new_method
@ -184,17 +89,11 @@ def check_de_type(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
data_type = (list(args) + [None])[0] [data_type], _ = parse_user_args(method, *args, **kwargs)
if "data_type" in kwargs:
data_type = kwargs.get("data_type")
if data_type is None: type_check(data_type, (typing.Type,), "data_type")
raise ValueError("data_type is not provided.")
if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.")
kwargs["data_type"] = data_type
return method(self, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method
@ -204,13 +103,11 @@ def check_slice_op(method):
@wraps(method) @wraps(method)
def new_method(self, *args): def new_method(self, *args):
for i, arg in enumerate(args): for _, arg in enumerate(args):
if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg")
raise TypeError("Indexing of dim " + str(i) + "is not of valid type")
if isinstance(arg, list): if isinstance(arg, list):
for a in arg: for a in arg:
if not isinstance(a, int): type_check(a, (int,), "a")
raise TypeError("Index " + a + " is not an int")
return method(self, *args) return method(self, *args)
return new_method return new_method
@ -221,36 +118,14 @@ def check_mask_op(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
operator, constant, dtype = (list(args) + 3 * [None])[:3] [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
if "operator" in kwargs:
operator = kwargs.get("operator")
if "constant" in kwargs:
constant = kwargs.get("constant")
if "dtype" in kwargs:
dtype = kwargs.get("dtype")
if operator is None:
raise ValueError("operator is not provided.")
if constant is None:
raise ValueError("constant is not provided.")
from .c_transforms import Relational from .c_transforms import Relational
if not isinstance(operator, Relational): type_check(operator, (Relational,), "operator")
raise TypeError("operator is not a Relational operator enum.") type_check(constant, (str, float, bool, int, bytes), "constant")
type_check(dtype, (typing.Type,), "dtype")
if not isinstance(constant, (str, float, bool, int, bytes)): return method(self, *args, **kwargs)
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
if dtype is not None:
if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
kwargs["dtype"] = dtype
kwargs["operator"] = operator
kwargs["constant"] = constant
return method(self, **kwargs)
return new_method return new_method
@ -260,22 +135,12 @@ def check_pad_end(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
pad_shape, pad_value = (list(args) + 2 * [None])[:2]
if "pad_shape" in kwargs:
pad_shape = kwargs.get("pad_shape")
if "pad_value" in kwargs:
pad_value = kwargs.get("pad_value")
if pad_shape is None: [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
raise ValueError("pad_shape is not provided.")
if pad_value is not None: if pad_value is not None:
if not isinstance(pad_value, (str, float, bool, int, bytes)): type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes") type_check(pad_shape, (list,), "pad_end")
kwargs["pad_value"] = pad_value
if not isinstance(pad_shape, list):
raise TypeError("pad_shape must be a list")
for dim in pad_shape: for dim in pad_shape:
if dim is not None: if dim is not None:
@ -284,9 +149,7 @@ def check_pad_end(method):
else: else:
raise TypeError("a value in the list is not an integer.") raise TypeError("a value in the list is not an integer.")
kwargs["pad_shape"] = pad_shape return method(self, *args, **kwargs)
return method(self, **kwargs)
return new_method return new_method
@ -296,31 +159,24 @@ def check_concat_type(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
axis, prepend, append = (list(args) + 3 * [None])[:3]
if "prepend" in kwargs: [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
prepend = kwargs.get("prepend")
if "append" in kwargs:
append = kwargs.get("append")
if "axis" in kwargs:
axis = kwargs.get("axis")
if axis is not None: if axis is not None:
if not isinstance(axis, int): type_check(axis, (int,), "axis")
raise TypeError("axis type is not valid, must be an integer.")
if axis not in (0, -1): if axis not in (0, -1):
raise ValueError("only 1D concatenation supported.") raise ValueError("only 1D concatenation supported.")
kwargs["axis"] = axis
if prepend is not None: if prepend is not None:
if not isinstance(prepend, (type(None), np.ndarray)): type_check(prepend, (np.ndarray,), "prepend")
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") if len(prepend.shape) != 1:
kwargs["prepend"] = prepend raise ValueError("can only prepend 1D arrays.")
if append is not None: if append is not None:
if not isinstance(append, (type(None), np.ndarray)): type_check(append, (np.ndarray,), "append")
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") if len(append.shape) != 1:
kwargs["append"] = append raise ValueError("can only append 1D arrays.")
return method(self, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method


@@ -40,12 +40,14 @@ Examples:
>>> dataset = dataset.map(input_columns="image", operations=transforms_list)
>>> dataset = dataset.map(input_columns="label", operations=onehot_op)
"""
+import numbers
import mindspore._c_dataengine as cde
from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
-check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
-check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp
+check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
+check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
+FLOAT_MAX_INTEGER
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
@@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC}
+def parse_padding(padding):
+if isinstance(padding, numbers.Number):
+padding = [padding] * 4
+if len(padding) == 2:
+left = right = padding[0]
+top = bottom = padding[1]
+padding = (left, top, right, bottom,)
+if isinstance(padding, list):
+padding = tuple(padding)
+return padding
class Decode(cde.DecodeOp):
"""
Decode the input image in RGB mode.
@@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp):
@check_random_crop
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
+if isinstance(size, int):
+size = (size, size)
+if padding is None:
+padding = (0, 0, 0, 0)
+else:
+padding = parse_padding(padding)
+if isinstance(fill_value, int): # temporary fix
+fill_value = tuple([fill_value] * 3)
+border_type = DE_C_BORDER_TYPE[padding_mode]
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill_value = fill_value
self.padding_mode = padding_mode.value
-if padding is None:
-padding = (0, 0, 0, 0)
-if isinstance(fill_value, int): # temporary fix
-fill_value = tuple([fill_value] * 3)
-border_type = DE_C_BORDER_TYPE[padding_mode]
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
@@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp):
@check_random_crop
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
+if isinstance(size, int):
+size = (size, size)
+if padding is None:
+padding = (0, 0, 0, 0)
+else:
+padding = parse_padding(padding)
+if isinstance(fill_value, int): # temporary fix
+fill_value = tuple([fill_value] * 3)
+border_type = DE_C_BORDER_TYPE[padding_mode]
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill_value = fill_value
self.padding_mode = padding_mode.value
-if padding is None:
-padding = (0, 0, 0, 0)
-if isinstance(fill_value, int): # temporary fix
-fill_value = tuple([fill_value] * 3)
-border_type = DE_C_BORDER_TYPE[padding_mode]
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
@@ -292,6 +319,8 @@ class Resize(cde.ResizeOp):
@check_resize_interpolation
def __init__(self, size, interpolation=Inter.LINEAR):
+if isinstance(size, int):
+size = (size, size)
self.size = size
self.interpolation = interpolation
interpoltn = DE_C_INTER_MODE[interpolation]
@@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
+if isinstance(size, int):
+size = (size, size)
self.size = size
self.scale = scale
self.ratio = ratio
@@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp):
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
+if isinstance(size, int):
+size = (size, size)
self.size = size
self.scale = scale
self.ratio = ratio
@@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp):
@check_crop
def __init__(self, size):
+if isinstance(size, int):
+size = (size, size)
self.size = size
super().__init__(*size)
@@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp):
@check_random_color_adjust
def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)):
+brightness = self.expand_values(brightness)
+contrast = self.expand_values(contrast)
+saturation = self.expand_values(saturation)
+hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False)
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
super().__init__(*brightness, *contrast, *saturation, *hue)
+def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
+if isinstance(value, numbers.Number):
+value = [center - value, center + value]
+if non_negative:
+value[0] = max(0, value[0])
+check_range(value, bound)
+return (value[0], value[1])
class RandomRotation(cde.RandomRotationOp):
"""
@@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp):
self.expand = expand
self.center = center
self.fill_value = fill_value
+if isinstance(degrees, numbers.Number):
+degrees = (-degrees, degrees)
if center is None:
center = (-1, -1)
if isinstance(fill_value, int): # temporary fix
@@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp):
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
+if isinstance(size, int):
+size = (size, size)
self.size = size
self.scale = scale
self.ratio = ratio
@@ -623,12 +676,14 @@ class Pad(cde.PadOp):
@check_pad
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
-self.padding = padding
+padding = parse_padding(padding)
-self.fill_value = fill_value
-self.padding_mode = padding_mode
if isinstance(fill_value, int): # temporary fix
fill_value = tuple([fill_value] * 3)
padding_mode = DE_C_BORDER_TYPE[padding_mode]
+self.padding = padding
+self.fill_value = fill_value
+self.padding_mode = padding_mode
super().__init__(*padding, padding_mode, *fill_value)
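
The two helpers added above only normalize user input into the tuple forms the C++ ops expect. A quick sketch of their behaviour, derived from the definitions shown in this diff (illustrative, not part of the commit):

parse_padding(5)             # -> (5, 5, 5, 5)
parse_padding([4, 8])        # -> (4, 8, 4, 8)   left/right = 4, top/bottom = 8
parse_padding((1, 2, 3, 4))  # -> (1, 2, 3, 4)   already fully specified

op = RandomColorAdjust(brightness=0.5, hue=0.2)
# expand_values turns scalars into ranges around the center:
# brightness 0.5 -> (0.5, 1.5), hue 0.2 -> (-0.2, 0.2)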


@@ -28,6 +28,7 @@ import numpy as np
from PIL import Image
from . import py_transforms_util as util
+from .c_transforms import parse_padding
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \
@@ -295,6 +296,10 @@ class RandomCrop:
@check_random_crop
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
+if padding is None:
+padding = (0, 0, 0, 0)
+else:
+padding = parse_padding(padding)
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
@@ -753,6 +758,8 @@ class TenCrop:
@check_ten_crop
def __init__(self, size, use_vertical_flip=False):
+if isinstance(size, int):
+size = (size, size)
self.size = size
self.use_vertical_flip = use_vertical_flip
@@ -877,6 +884,8 @@ class Pad:
@check_pad
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
+parse_padding(padding)
self.padding = padding
self.fill_value = fill_value
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
@@ -1129,56 +1138,23 @@ class RandomAffine:

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0):
        # Parameter checking
        # rotation
-        if isinstance(degrees, numbers.Number):
-            if degrees < 0:
-                raise ValueError("If degrees is a single number, it must be positive.")
-            self.degrees = (-degrees, degrees)
-        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
-            self.degrees = degrees
-        else:
-            raise TypeError("If degrees is a list or tuple, it must be of length 2.")
-        # translation
-        if translate is not None:
-            if isinstance(translate, (tuple, list)) and len(translate) == 2:
-                for t in translate:
-                    if t < 0.0 or t > 1.0:
-                        raise ValueError("translation values should be between 0 and 1")
-            else:
-                raise TypeError("translate should be a list or tuple of length 2.")
-        self.translate = translate
-        # scale
-        if scale is not None:
-            if isinstance(scale, (tuple, list)) and len(scale) == 2:
-                for s in scale:
-                    if s <= 0:
-                        raise ValueError("scale values should be positive")
-            else:
-                raise TypeError("scale should be a list or tuple of length 2.")
-        self.scale_ranges = scale
-        # shear
        if shear is not None:
            if isinstance(shear, numbers.Number):
-                if shear < 0:
-                    raise ValueError("If shear is a single number, it must be positive.")
-                self.shear = (-1 * shear, shear)
-            elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4):
-                # X-Axis shear with [min, max]
-                if len(shear) == 2:
-                    self.shear = [shear[0], shear[1], 0., 0.]
-                elif len(shear) == 4:
-                    self.shear = [s for s in shear]
+                shear = (-1 * shear, shear)
            else:
-                raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.")
-        else:
-            self.shear = shear
-        # resample
+                if len(shear) == 2:
+                    shear = [shear[0], shear[1], 0., 0.]
+                elif len(shear) == 4:
+                    shear = [s for s in shear]
+
+        if isinstance(degrees, numbers.Number):
+            degrees = (-degrees, degrees)
+
+        self.degrees = degrees
+        self.translate = translate
+        self.scale_ranges = scale
+        self.shear = shear
        self.resample = DE_PY_INTER_MODE[resample]
-        # fill_value
        self.fill_value = fill_value

    def __call__(self, img):
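To make the slimmed-down RandomAffine.__init__ above easier to follow, here is a standalone sketch of just its normalization step (pure Python, no MindSpore imports; the behavior mirrors the added lines in the hunk, while the function name is illustrative only):

import numbers

def normalize_affine_args(degrees, shear=None):
    # A scalar degrees value becomes a symmetric range; a scalar shear becomes an
    # X-axis range; a 2-element shear is padded to [x_min, x_max, y_min, y_max].
    if shear is not None:
        if isinstance(shear, numbers.Number):
            shear = (-1 * shear, shear)
        elif len(shear) == 2:
            shear = [shear[0], shear[1], 0., 0.]
        elif len(shear) == 4:
            shear = [s for s in shear]
    if isinstance(degrees, numbers.Number):
        degrees = (-degrees, degrees)
    return degrees, shear

print(normalize_affine_args(15, shear=5))            # ((-15, 15), (-5, 5))
print(normalize_affine_args((0, 30), shear=(1, 2)))  # ((0, 30), [1, 2, 0.0, 0.0])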
File diff suppressed because it is too large

View File
@@ -15,13 +15,15 @@
"""
Testing the bounding box augment op in DE
"""
-from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
-    config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
import numpy as np
import mindspore.log as logger
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
+from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
+    config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5

GENERATE_GOLDEN = False

# updated VOC dataset with correct annotations

@@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c():
                                    operations=[test_op])  # Add column for "annotation"
    except ValueError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input is not" in str(error)
+        assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)


def test_bounding_box_augment_invalid_bounds_c():
View File
@@ -17,6 +17,7 @@ import pytest
import numpy as np
import mindspore.dataset as ds


# generates 1 column [0], [0, 1], ..., [0, ..., n-1]
def generate_sequential(n):
    for i in range(n):

@@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input():
    with pytest.raises(TypeError) as info:
        _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
                                           None, None, invalid_type_pad_to_bucket_boundary)
-    assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
+    assert "Argument pad_to_bucket_boundary with value \"\" is not of type (<class \'bool\'>,)." in str(info.value)

    with pytest.raises(TypeError) as info:
        _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
                                           None, None, False, invalid_type_drop_remainder)
-    assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
+    assert "Argument drop_remainder with value \"\" is not of type (<class 'bool'>,)." in str(info.value)


def test_bucket_batch_multi_bucket_no_padding():

@@ -272,7 +273,6 @@ def test_bucket_batch_default_pad():
                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]]

    output = []
    for data in dataset.create_dict_iterator():
        output.append(data["col1"].tolist())
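Most of the assertion updates in these tests follow one message template. Below is a rough standalone sketch of a type_check-style helper that would produce such messages; it illustrates the pattern only, and the real validator in the MindSpore validator module may format values and types differently.

def type_check_sketch(arg, types, arg_name):
    # Raise a TypeError whose text matches the pattern asserted above:
    # "Argument <name> with value <value> is not of type <types>."
    if not isinstance(arg, types):
        raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, arg, types))

try:
    type_check_sketch("yes", (bool,), "drop_remainder")
except TypeError as err:
    print(err)  # Argument drop_remainder with value yes is not of type (<class 'bool'>,).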
View File
@@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis():

def test_concatenate_op_incorrect_input_dim():
-    def gen():
-        yield (np.array(["ss", "ad"], dtype='S'),)
-
    prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
-    data = ds.GeneratorDataset(gen, column_names=["col"])
-    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
-    data = data.map(input_columns=["col"], operations=concatenate_op)
-    with pytest.raises(RuntimeError) as error_info:
-        for _ in data:
-            pass
-    assert "Only 1D tensors supported" in repr(error_info.value)
+    with pytest.raises(ValueError) as error_info:
+        data_trans.Concatenate(0, prepend_tensor)
+    assert "can only prepend 1D arrays." in repr(error_info.value)


if __name__ == "__main__":
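For contrast with the failing case above, a hedged sketch of a Concatenate construction that should pass validation: the prepend tensor is 1-D. The data_trans alias and import path are assumed to match this test module's own import convention.

import numpy as np
import mindspore.dataset.transforms.c_transforms as data_trans  # assumed import, matching the test's alias

prepend_1d = np.array(["ss", "ad"], dtype='S')   # 1-D, so it can be prepended
concatenate_op = data_trans.Concatenate(0, prepend_1d)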
View File
@@ -28,9 +28,9 @@ def test_exception_01():
    """
    logger.info("test_exception_01")
    data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
-    with pytest.raises(ValueError) as info:
-        data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))
-    assert "Invalid interpolation mode." in str(info.value)
+    with pytest.raises(TypeError) as info:
+        data.map(input_columns=["image"], operations=vision.Resize(100, 100))
+    assert "Argument interpolation with value 100 is not of type (<enum 'Inter'>,)" in str(info.value)


def test_exception_02():

@@ -40,8 +40,8 @@ def test_exception_02():
    logger.info("test_exception_02")
    num_samples = -1
    with pytest.raises(ValueError) as info:
-        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
-    assert "num_samples cannot be less than 0" in str(info.value)
+        ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value)

    num_samples = 1
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
View File
@@ -23,7 +23,8 @@ import mindspore.dataset.text as text

def test_demo_basic_from_dataset():
    """ this is a tutorial on how from_dataset should be used in a normal use case"""
    data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
-    vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
+    vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None,
+                                    special_tokens=["<pad>", "<unk>"],
                                    special_first=True)
    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
    res = []

@@ -127,15 +128,16 @@ def test_from_dataset_exceptions():
            data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
            vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
            assert isinstance(vocab.text.Vocab)
-        except ValueError as e:
+        except (TypeError, ValueError, RuntimeError) as e:
            assert s in str(e), str(e)

-    test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
-    test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
-    test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
-    test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
-    test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
-    test_config([123], (2, 3), 0, "columns need to be a list of strings")
+    test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
+    test_config("text", (2, 3), 1.2345,
+                "Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
+    test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
+    test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
+    test_config("text", (2, 3), 0, "top_k needs to be positive number")
+    test_config([123], (2, 3), 0, "top_k needs to be positive number")


if __name__ == '__main__':
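Summarizing the constraints these updated messages encode: columns must be strings, freq_range must be a 2-tuple with 0 <= a <= b (or None), and top_k must be a positive int or None. A valid call for comparison; the file path and column name are the test fixture's own, reused here only as an example.

import mindspore.dataset as ds
import mindspore.dataset.text as text

data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, "text", freq_range=(2, 10), top_k=5,
                                special_tokens=["<pad>", "<unk>"], special_first=True)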
View File
@@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False):
    if plot:
        visualize_list(image, image_transformed)


def test_linear_transformation_md5():
    """
    Test LinearTransformation op: valid params (transformation_matrix, mean_vector)

@@ -102,6 +103,7 @@ def test_linear_transformation_md5():
    filename = "linear_transformation_01_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_linear_transformation_exception_01():
    """
    Test LinearTransformation op: transformation_matrix is not provided

@@ -126,9 +128,10 @@ def test_linear_transformation_exception_01():
        ]
        transform = py_vision.ComposeOp(transforms)
        data1 = data1.map(input_columns=["image"], operations=transform())
-    except ValueError as e:
+    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "not provided" in str(e)
+        assert "Argument transformation_matrix with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)


def test_linear_transformation_exception_02():
    """

@@ -154,9 +157,10 @@ def test_linear_transformation_exception_02():
        ]
        transform = py_vision.ComposeOp(transforms)
        data1 = data1.map(input_columns=["image"], operations=transform())
-    except ValueError as e:
+    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "not provided" in str(e)
+        assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)


def test_linear_transformation_exception_03():
    """

@@ -187,6 +191,7 @@ def test_linear_transformation_exception_03():
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "square matrix" in str(e)


def test_linear_transformation_exception_04():
    """
    Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix

@@ -199,7 +204,7 @@ def test_linear_transformation_exception_04():
    weight = 50
    dim = 3 * height * weight
    transformation_matrix = np.ones([dim, dim])
-    mean_vector = np.zeros(dim-1)
+    mean_vector = np.zeros(dim - 1)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

@@ -216,6 +221,7 @@ def test_linear_transformation_exception_04():
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "should match" in str(e)


if __name__ == '__main__':
    test_linear_transformation_op(plot=True)
    test_linear_transformation_md5()
View File
@@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards():
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
-    with pytest.raises(Exception, match="shard_id is invalid, "):
+    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
        num_iter = 0
        for _ in data_set.create_dict_iterator():
            num_iter += 1
+    assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


def test_minddataset_invalidate_shard_id():
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
-    with pytest.raises(Exception, match="shard_id is invalid, "):
+    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
        num_iter = 0
        for _ in data_set.create_dict_iterator():
            num_iter += 1
+    assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))

@@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard():
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
-    with pytest.raises(Exception, match="shard_id is invalid, "):
+    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
        num_iter = 0
        for _ in data_set.create_dict_iterator():
            num_iter += 1
+    assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)

-    with pytest.raises(Exception, match="shard_id is invalid, "):
+    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
        num_iter = 0
        for _ in data_set.create_dict_iterator():
            num_iter += 1
+    assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)

    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
View File
@@ -15,9 +15,9 @@
"""
Testing Ngram in mindspore.dataset
"""
-import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
+import numpy as np


def test_multiple_ngrams():

@@ -61,7 +61,7 @@ def test_simple_ngram():
        yield (np.array(line.split(" "), dtype='S'),)

    dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
-    dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None))
+    dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" "))

    i = 0
    for data in dataset.create_dict_iterator():

@@ -72,7 +72,7 @@ def test_simple_ngram():

def test_corner_cases():
    """ testing various corner cases and exceptions"""

-    def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None):
+    def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
        def gen(texts):
            yield (np.array(texts.split(" "), dtype='S'),)

@@ -93,7 +93,7 @@ def test_corner_cases():
    try:
        test_config("Yours to Discover", "", [0, [1]])
    except Exception as e:
-        assert "ngram needs to be a positive number" in str(e)
+        assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
    # test empty n
    try:
        test_config("Yours to Discover", "", [])
View File
@@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py():
        _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input is not within the required range" in str(e)
+        assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)


def test_normalize_grayscale_md5_01():
View File
@@ -61,6 +61,10 @@ def test_pad_end_exceptions():
        pad_compare([3, 4, 5], ["2"], 1, [])
    assert "a value in the list is not an integer." in str(info.value)

+    with pytest.raises(TypeError) as info:
+        pad_compare([1, 2], 3, -1, [1, 2, -1])
+    assert "Argument pad_end with value 3 is not of type (<class 'list'>,)" in str(info.value)
+

if __name__ == "__main__":
    test_pad_end_basics()
View File
@@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees():
        _ = py_vision.RandomAffine(degrees=-15)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "If degrees is a single number, it cannot be negative."
+        assert str(e) == "Input degrees is not within the required interval of (0 to inf)."


def test_random_affine_exception_translation_range():

@@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range():
        _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5))
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "translation values should be between 0 and 1"
+        assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)."


def test_random_affine_exception_scale_value():

@@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value():
        _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1))
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "scale values should be positive"
+        assert str(e) == "Input scale[0] must be greater than 0."


def test_random_affine_exception_shear_value():

@@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value():
        _ = py_vision.RandomAffine(degrees=15, shear=-5)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "If shear is a single number, it must be positive."
+        assert str(e) == "Input shear must be greater than 0."


def test_random_affine_exception_degrees_size():

@@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size():
        _ = py_vision.RandomAffine(degrees=15, translate=(0.1))
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "translate should be a list or tuple of length 2."
+        assert str(
+            e) == "Argument translate with value 0.1 is not of type (<class 'list'>," \
+                  " <class 'tuple'>)."


def test_random_affine_exception_scale_size():

@@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size():
        _ = py_vision.RandomAffine(degrees=15, scale=(0.5))
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "scale should be a list or tuple of length 2."
+        assert str(e) == "Argument scale with value 0.5 is not of type (<class 'tuple'>," \
+                         " <class 'list'>)."


def test_random_affine_exception_shear_size():

@@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size():
        _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10))
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4."
+        assert str(e) == "shear must be of length 2 or 4."


if __name__ == "__main__":
View File
@@ -97,7 +97,7 @@ def test_random_color_md5():
    data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    transforms = F.ComposeOp([F.Decode(),
-                              F.RandomColor((0.5, 1.5)),
+                              F.RandomColor((0.1, 1.9)),
                              F.ToTensor()])

    data = data.map(input_columns="image", operations=transforms())
View File
@@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c():
        data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input range is not valid" in str(e)
+        assert "Input is not within the required interval of (0 to 16777216)." in str(e)


def test_random_crop_and_resize_04_py():

@@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py():
        data = data.map(input_columns=["image"], operations=transform())
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input range is not valid" in str(e)
+        assert "Input is not within the required interval of (0 to 16777216)." in str(e)


def test_random_crop_and_resize_05_c():

@@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c():
        data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input range is not valid" in str(e)
+        assert "Input is not within the required interval of (0 to 16777216)." in str(e)


def test_random_crop_and_resize_05_py():

@@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py():
        data = data.map(input_columns=["image"], operations=transform())
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input range is not valid" in str(e)
+        assert "Input is not within the required interval of (0 to 16777216)." in str(e)


def test_random_crop_and_resize_comp(plot=False):
View File
@@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
    except ValueError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
-        assert "Input range is not valid" in str(err)
+        assert "Input is not within the required interval of (0 to 16777216)." in str(err)


def test_random_resized_crop_with_bbox_op_invalid2_c():

@@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
    except ValueError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
-        assert "Input range is not valid" in str(err)
+        assert "Input is not within the required interval of (0 to 16777216)." in str(err)


def test_random_resized_crop_with_bbox_op_bad_c():
View File
@@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param():
        data = data.map(input_columns=["image"], operations=transform())
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input is not within the required range" in str(e)
+        assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)


if __name__ == "__main__":
    test_random_grayscale_valid_prob(True)
View File
@@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c():
        data = data.map(input_columns=["image"], operations=random_horizontal_op)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input is not" in str(e)
+        assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)


def test_random_horizontal_invalid_prob_py():

@@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py():
        data = data.map(input_columns=["image"], operations=transform())
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input is not" in str(e)
+        assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)


def test_random_horizontal_comp(plot=False):
View File
@@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
                                    operations=[test_op])  # Add column for "annotation"
    except ValueError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input is not" in str(error)
+        assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error)


def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
View File
@@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range():
        _ = py_vision.RandomPerspective(distortion_scale=1.5)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "Input is not within the required range"
+        assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)."


def test_random_perspective_exception_prob_range():

@@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range():
        _ = py_vision.RandomPerspective(prob=1.2)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert str(e) == "Input is not within the required range"
+        assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)."


if __name__ == "__main__":
View File
@@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c():
    except ValueError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
-        assert "Input is not" in str(err)
+        assert "Input is not within the required interval of (1 to 16777216)." in str(err)

    try:
        # one of the size values is zero

@@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c():
    except ValueError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
-        assert "Input is not" in str(err)
+        assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err)

    try:
        # negative value for resize

@@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c():
    except ValueError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
-        assert "Input is not" in str(err)
+        assert "Input is not within the required interval of (1 to 16777216)." in str(err)

    try:
        # invalid input shape
View File
@@ -97,7 +97,7 @@ def test_random_sharpness_md5():
    # define map operations
    transforms = [
        F.Decode(),
-        F.RandomSharpness((0.5, 1.5)),
+        F.RandomSharpness((0.1, 1.9)),
        F.ToTensor()
    ]
    transform = F.ComposeOp(transforms)
View File
@@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c():
        data = data.map(input_columns=["image"], operations=random_horizontal_op)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input is not" in str(e)
+        assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)


def test_random_vertical_invalid_prob_py():

@@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py():
        data = data.map(input_columns=["image"], operations=transform())
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Input is not" in str(e)
+        assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)


def test_random_vertical_comp(plot=False):
View File
@@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
    except ValueError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
-        assert "Input is not" in str(err)
+        assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err)


def test_random_vertical_flip_with_bbox_op_bad_c():
View File
@@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c():
        # invalid interpolation value
        c_vision.ResizeWithBBox(400, interpolation="invalid")
-    except ValueError as err:
+    except TypeError as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
        assert "interpolation" in str(err)
View File
@@ -154,7 +154,7 @@ def test_shuffle_exception_01():
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "buffer_size" in str(e)
+        assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)


def test_shuffle_exception_02():

@@ -172,7 +172,7 @@ def test_shuffle_exception_02():
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "buffer_size" in str(e)
+        assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)


def test_shuffle_exception_03():

@@ -190,7 +190,7 @@ def test_shuffle_exception_03():
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "buffer_size" in str(e)
+        assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)


def test_shuffle_exception_05():
View File
@@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
    logger.info("dtype of image_2: {}".format(image_2.dtype))

    if plot:
-        visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
+        visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))

    # The output data should be of a 4D tensor shape, a stack of 10 images.
    assert len(image_2.shape) == 4

@@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg():
            vision.TenCrop(0),
            lambda images: np.stack([vision.ToTensor()(image) for image in images])  # 4D stack of 10 images
        ]
-    error_msg = "Input is not within the required range"
+    error_msg = "Input is not within the required interval of (1 to 16777216)."
    assert error_msg == str(info.value)

    with pytest.raises(ValueError) as info:
View File
@@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "operations" in str(e)
+        assert "Argument tensor_op_5 with value" \
+               " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e)
+        assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e)


def test_cpp_uniform_augment_exception_large_numops(num_ops=6):

@@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "num_ops" in str(e)
+        assert "Input num_ops must be greater than 0" in str(e)


def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):

@@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "integer" in str(e)
+        assert "Argument num_ops with value 2.5 is not of type (<class 'int'>,)" in str(e)


def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
View File
@ -314,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
if len(orig) != len(aug) or not orig: if len(orig) != len(aug) or not orig:
return return
batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together
split_point = batch_size * plot_rows split_point = batch_size * plot_rows
orig, aug = np.array(orig), np.array(aug) orig, aug = np.array(orig), np.array(aug)
if len(orig) > plot_rows: if len(orig) > plot_rows:
# Create batches of required size and add remainder to last batch # Create batches of required size and add remainder to last batch
orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added orig = np.split(orig[:split_point], batch_size) + (
[orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
else: else:
orig = [orig] orig = [orig]
@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
cur_ix = base_ix + x cur_ix = base_ix + x
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row # select plotting axes based on number of image rows on plot - else case when 1 row
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1])
axA.imshow(dataA["image"]) axA.imshow(dataA["image"])
add_bounding_boxes(axA, dataA[annot_name]) add_bounding_boxes(axA, dataA[annot_name])
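The reflowed np.split call above is easy to misread, so here is a worked example of the batching arithmetic in visualize_with_bounding_boxes with illustrative numbers (NumPy only; the variable values are made up for the example):

import numpy as np

orig = np.arange(7)                          # pretend: 7 images, 3 rows per figure
plot_rows = 3
batch_size = int(len(orig) / plot_rows)      # 2 -> number of full figures
split_point = batch_size * plot_rows         # 6 -> images covered by full figures

# Split the first 6 items into 2 groups of 3, then append the remainder as its own group.
batches = np.split(orig[:split_point], batch_size) + (
    [orig[split_point:]] if split_point < orig.shape[0] else [])
print([b.tolist() for b in batches])         # [[0, 1, 2], [3, 4, 5], [6]]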