fix validators
fixed random_apply tests fix validators fixed random_apply tests fix engine validation
This commit is contained in:
parent
915ddd25dd
commit
6c37ea3be0
|
@ -0,0 +1,342 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
General Validators.
|
||||
"""
|
||||
import inspect
|
||||
from multiprocessing import cpu_count
|
||||
import os
|
||||
import numpy as np
|
||||
from ..engine import samplers
|
||||
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
POS_INT_MIN = 1
|
||||
UINT8_MAX = 255
|
||||
UINT8_MIN = 0
|
||||
UINT32_MAX = 4294967295
|
||||
UINT32_MIN = 0
|
||||
UINT64_MAX = 18446744073709551615
|
||||
UINT64_MIN = 0
|
||||
INT32_MAX = 2147483647
|
||||
INT32_MIN = -2147483648
|
||||
INT64_MAX = 9223372036854775807
|
||||
INT64_MIN = -9223372036854775808
|
||||
FLOAT_MAX_INTEGER = 16777216
|
||||
FLOAT_MIN_INTEGER = -16777216
|
||||
DOUBLE_MAX_INTEGER = 9007199254740992
|
||||
DOUBLE_MIN_INTEGER = -9007199254740992
|
||||
|
||||
valid_detype = [
|
||||
"bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
|
||||
"uint32", "uint64", "float16", "float32", "float64", "string"
|
||||
]
|
||||
|
||||
|
||||
def pad_arg_name(arg_name):
|
||||
if arg_name != "":
|
||||
arg_name = arg_name + " "
|
||||
return arg_name
|
||||
|
||||
|
||||
def check_value(value, valid_range, arg_name=""):
|
||||
arg_name = pad_arg_name(arg_name)
|
||||
if value < valid_range[0] or value > valid_range[1]:
|
||||
raise ValueError(
|
||||
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
|
||||
valid_range[1]))
|
||||
|
||||
|
||||
def check_range(values, valid_range, arg_name=""):
|
||||
arg_name = pad_arg_name(arg_name)
|
||||
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
|
||||
raise ValueError(
|
||||
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
|
||||
valid_range[1]))
|
||||
|
||||
|
||||
def check_positive(value, arg_name=""):
|
||||
arg_name = pad_arg_name(arg_name)
|
||||
if value <= 0:
|
||||
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
|
||||
|
||||
|
||||
def check_positive_float(value, arg_name=""):
|
||||
arg_name = pad_arg_name(arg_name)
|
||||
type_check(value, (float,), arg_name)
|
||||
check_positive(value, arg_name)
|
||||
|
||||
|
||||
def check_2tuple(value, arg_name=""):
|
||||
if not (isinstance(value, tuple) and len(value) == 2):
|
||||
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
|
||||
|
||||
|
||||
def check_uint8(value, arg_name=""):
|
||||
type_check(value, (int,), arg_name)
|
||||
check_value(value, [UINT8_MIN, UINT8_MAX])
|
||||
|
||||
|
||||
def check_uint32(value, arg_name=""):
|
||||
type_check(value, (int,), arg_name)
|
||||
check_value(value, [UINT32_MIN, UINT32_MAX])
|
||||
|
||||
|
||||
def check_pos_int32(value, arg_name=""):
|
||||
type_check(value, (int,), arg_name)
|
||||
check_value(value, [POS_INT_MIN, INT32_MAX])
|
||||
|
||||
|
||||
def check_uint64(value, arg_name=""):
|
||||
type_check(value, (int,), arg_name)
|
||||
check_value(value, [UINT64_MIN, UINT64_MAX])
|
||||
|
||||
|
||||
def check_pos_int64(value, arg_name=""):
|
||||
type_check(value, (int,), arg_name)
|
||||
check_value(value, [UINT64_MIN, INT64_MAX])
|
||||
|
||||
|
||||
def check_pos_float32(value, arg_name=""):
|
||||
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)
|
||||
|
||||
|
||||
def check_pos_float64(value, arg_name=""):
|
||||
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name)
|
||||
|
||||
|
||||
def check_valid_detype(type_):
|
||||
if type_ not in valid_detype:
|
||||
raise ValueError("Unknown column type")
|
||||
return True
|
||||
|
||||
|
||||
def check_columns(columns, name):
|
||||
type_check(columns, (list, str), name)
|
||||
if isinstance(columns, list):
|
||||
if not columns:
|
||||
raise ValueError("Column names should not be empty")
|
||||
col_names = ["col_{0}".format(i) for i in range(len(columns))]
|
||||
type_check_list(columns, (str,), col_names)
|
||||
|
||||
|
||||
def parse_user_args(method, *args, **kwargs):
|
||||
"""
|
||||
Parse user arguments in a function
|
||||
|
||||
Args:
|
||||
method (method): a callable function
|
||||
*args: user passed args
|
||||
**kwargs: user passed kwargs
|
||||
|
||||
Returns:
|
||||
user_filled_args (list): values of what the user passed in for the arguments,
|
||||
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
|
||||
"""
|
||||
sig = inspect.signature(method)
|
||||
if 'self' in sig.parameters or 'cls' in sig.parameters:
|
||||
ba = sig.bind(method, *args, **kwargs)
|
||||
ba.apply_defaults()
|
||||
params = list(sig.parameters.keys())[1:]
|
||||
else:
|
||||
ba = sig.bind(*args, **kwargs)
|
||||
ba.apply_defaults()
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
user_filled_args = [ba.arguments.get(arg_value) for arg_value in params]
|
||||
return user_filled_args, ba.arguments
|
||||
|
||||
|
||||
def type_check_list(args, types, arg_names):
|
||||
"""
|
||||
Check the type of each parameter in the list
|
||||
|
||||
Args:
|
||||
args (list, tuple): a list or tuple of any variable
|
||||
types (tuple): tuple of all valid types for arg
|
||||
arg_names (list, tuple of str): the names of args
|
||||
|
||||
Returns:
|
||||
Exception: when the type is not correct, otherwise nothing
|
||||
"""
|
||||
type_check(args, (list, tuple,), arg_names)
|
||||
if len(args) != len(arg_names):
|
||||
raise ValueError("List of arguments is not the same length as argument_names.")
|
||||
for arg, arg_name in zip(args, arg_names):
|
||||
type_check(arg, types, arg_name)
|
||||
|
||||
|
||||
def type_check(arg, types, arg_name):
|
||||
"""
|
||||
Check the type of the parameter
|
||||
|
||||
Args:
|
||||
arg : any variable
|
||||
types (tuple): tuple of all valid types for arg
|
||||
arg_name (str): the name of arg
|
||||
|
||||
Returns:
|
||||
Exception: when the type is not correct, otherwise nothing
|
||||
"""
|
||||
# handle special case of booleans being a subclass of ints
|
||||
print_value = '\"\"' if repr(arg) == repr('') else arg
|
||||
|
||||
if int in types and bool not in types:
|
||||
if isinstance(arg, bool):
|
||||
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
|
||||
if not isinstance(arg, types):
|
||||
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
|
||||
|
||||
|
||||
def check_filename(path):
|
||||
"""
|
||||
check the filename in the path
|
||||
|
||||
Args:
|
||||
path (str): the path
|
||||
|
||||
Returns:
|
||||
Exception: when error
|
||||
"""
|
||||
if not isinstance(path, str):
|
||||
raise TypeError("path: {} is not string".format(path))
|
||||
filename = os.path.basename(path)
|
||||
|
||||
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
|
||||
# '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
|
||||
# '*', '(', '%', ')', '-', '=', '{', '?', '$'
|
||||
forbidden_symbols = set(r'\/:*?"<>|`&\';')
|
||||
|
||||
if set(filename) & forbidden_symbols:
|
||||
raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
|
||||
|
||||
if filename.startswith(' ') or filename.endswith(' '):
|
||||
raise ValueError("filename should not start/end with space")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_dir(dataset_dir):
|
||||
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
|
||||
raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
|
||||
|
||||
|
||||
def check_file(dataset_file):
|
||||
check_filename(dataset_file)
|
||||
if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
|
||||
raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
|
||||
|
||||
|
||||
def check_sampler_shuffle_shard_options(param_dict):
|
||||
"""
|
||||
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
|
||||
Args:
|
||||
param_dict (dict): param_dict
|
||||
|
||||
Returns:
|
||||
Exception: ValueError or RuntimeError if error
|
||||
"""
|
||||
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
||||
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||
|
||||
type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
|
||||
|
||||
if sampler is not None:
|
||||
if shuffle is not None:
|
||||
raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
|
||||
|
||||
if num_shards is not None:
|
||||
check_pos_int32(num_shards)
|
||||
if shard_id is None:
|
||||
raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
|
||||
check_value(shard_id, [0, num_shards - 1], "shard_id")
|
||||
|
||||
if num_shards is None and shard_id is not None:
|
||||
raise RuntimeError("shard_id is specified but num_shards is not.")
|
||||
|
||||
|
||||
def check_padding_options(param_dict):
|
||||
"""
|
||||
Check for valid padded_sample and num_padded of padded samples
|
||||
|
||||
Args:
|
||||
param_dict (dict): param_dict
|
||||
|
||||
Returns:
|
||||
Exception: ValueError or RuntimeError if error
|
||||
"""
|
||||
|
||||
columns_list = param_dict.get('columns_list')
|
||||
block_reader = param_dict.get('block_reader')
|
||||
padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
|
||||
if padded_sample is not None:
|
||||
if num_padded is None:
|
||||
raise RuntimeError("padded_sample is specified and requires num_padded as well.")
|
||||
if num_padded < 0:
|
||||
raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
|
||||
if columns_list is None:
|
||||
raise RuntimeError("padded_sample is specified and requires columns_list as well.")
|
||||
for column in columns_list:
|
||||
if column not in padded_sample:
|
||||
raise ValueError("padded_sample cannot match columns_list.")
|
||||
if block_reader:
|
||||
raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.")
|
||||
|
||||
if padded_sample is None and num_padded is not None:
|
||||
raise RuntimeError("num_padded is specified but padded_sample is not.")
|
||||
|
||||
|
||||
def check_num_parallel_workers(value):
|
||||
type_check(value, (int,), "num_parallel_workers")
|
||||
if value < 1 or value > cpu_count():
|
||||
raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
|
||||
|
||||
|
||||
def check_num_samples(value):
|
||||
type_check(value, (int,), "num_samples")
|
||||
check_value(value, [0, INT32_MAX], "num_samples")
|
||||
|
||||
|
||||
def validate_dataset_param_value(param_list, param_dict, param_type):
|
||||
for param_name in param_list:
|
||||
if param_dict.get(param_name) is not None:
|
||||
if param_name == 'num_parallel_workers':
|
||||
check_num_parallel_workers(param_dict.get(param_name))
|
||||
if param_name == 'num_samples':
|
||||
check_num_samples(param_dict.get(param_name))
|
||||
else:
|
||||
type_check(param_dict.get(param_name), (param_type,), param_name)
|
||||
|
||||
|
||||
def check_gnn_list_or_ndarray(param, param_name):
|
||||
"""
|
||||
Check if the input parameter is list or numpy.ndarray.
|
||||
|
||||
Args:
|
||||
param (list, nd.ndarray): param
|
||||
param_name (str): param_name
|
||||
|
||||
Returns:
|
||||
Exception: TypeError if error
|
||||
"""
|
||||
|
||||
type_check(param, (list, np.ndarray), param_name)
|
||||
if isinstance(param, list):
|
||||
param_names = ["param_{0}".format(i) for i in range(len(param))]
|
||||
type_check_list(param, (int,), param_names)
|
||||
|
||||
elif isinstance(param, np.ndarray):
|
||||
if not param.dtype == np.int32:
|
||||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||
param_name, param.dtype))
|
File diff suppressed because it is too large
Load Diff
|
@ -98,7 +98,7 @@ class Ngram(cde.NgramOp):
|
|||
"""
|
||||
|
||||
@check_ngram
|
||||
def __init__(self, n, left_pad=None, right_pad=None, separator=None):
|
||||
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
|
||||
super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0],
|
||||
r_pad_token=right_pad[0], separator=separator)
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ __all__ = [
|
|||
"Vocab", "to_str", "to_bytes"
|
||||
]
|
||||
|
||||
|
||||
class Vocab(cde.Vocab):
|
||||
"""
|
||||
Vocab object that is used to lookup a word.
|
||||
|
@ -38,7 +39,7 @@ class Vocab(cde.Vocab):
|
|||
@classmethod
|
||||
@check_from_dataset
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
|
||||
special_first=None):
|
||||
special_first=True):
|
||||
"""
|
||||
Build a vocab from a dataset.
|
||||
|
||||
|
@ -62,13 +63,21 @@ class Vocab(cde.Vocab):
|
|||
special_tokens(list, optional): a list of strings, each one is a special token. for example
|
||||
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
|
||||
is specified and special_first is set to None, special_tokens will be prepended (default=None).
|
||||
is specified and special_first is set to True, special_tokens will be prepended (default=True).
|
||||
|
||||
Returns:
|
||||
Vocab, Vocab object built from dataset.
|
||||
"""
|
||||
|
||||
vocab = Vocab()
|
||||
if columns is None:
|
||||
columns = []
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
if freq_range is None:
|
||||
freq_range = (None, None)
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
|
||||
for d in root.create_dict_iterator():
|
||||
if d is not None:
|
||||
|
@ -77,7 +86,7 @@ class Vocab(cde.Vocab):
|
|||
|
||||
@classmethod
|
||||
@check_from_list
|
||||
def from_list(cls, word_list, special_tokens=None, special_first=None):
|
||||
def from_list(cls, word_list, special_tokens=None, special_first=True):
|
||||
"""
|
||||
Build a vocab object from a list of word.
|
||||
|
||||
|
@ -86,29 +95,33 @@ class Vocab(cde.Vocab):
|
|||
special_tokens(list, optional): a list of strings, each one is a special token. for example
|
||||
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
|
||||
is specified and special_first is set to None, special_tokens will be prepended (default=None).
|
||||
is specified and special_first is set to True, special_tokens will be prepended (default=True).
|
||||
"""
|
||||
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
return super().from_list(word_list, special_tokens, special_first)
|
||||
|
||||
@classmethod
|
||||
@check_from_file
|
||||
def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
|
||||
def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True):
|
||||
"""
|
||||
Build a vocab object from a list of word.
|
||||
|
||||
Args:
|
||||
file_path (str): path to the file which contains the vocab list.
|
||||
delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be
|
||||
the word (default=None).
|
||||
the word (default="").
|
||||
vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken).
|
||||
special_tokens (list, optional): a list of strings, each one is a special token. for example
|
||||
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
||||
special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
|
||||
If special_tokens is specified and special_first is set to None,
|
||||
special_tokens will be prepended (default=None).
|
||||
If special_tokens is specified and special_first is set to True,
|
||||
special_tokens will be prepended (default=True).
|
||||
"""
|
||||
|
||||
if vocab_size is None:
|
||||
vocab_size = -1
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -17,23 +17,22 @@ validators for text ops
|
|||
"""
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
from mindspore._c_expression import typing
|
||||
from ..transforms.validators import check_uint32, check_pos_int64
|
||||
|
||||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \
|
||||
INT32_MAX, check_value
|
||||
|
||||
|
||||
def check_unique_list_of_words(words, arg_name):
|
||||
"""Check that words is a list and each element is a str without any duplication"""
|
||||
|
||||
if not isinstance(words, list):
|
||||
raise ValueError(arg_name + " needs to be a list of words of type string.")
|
||||
type_check(words, (list,), arg_name)
|
||||
words_set = set()
|
||||
for word in words:
|
||||
if not isinstance(word, str):
|
||||
raise ValueError("each word in " + arg_name + " needs to be type str.")
|
||||
type_check(word, (str,), arg_name)
|
||||
if word in words_set:
|
||||
raise ValueError(arg_name + " contains duplicate word: " + word + ".")
|
||||
words_set.add(word)
|
||||
|
@ -45,21 +44,14 @@ def check_lookup(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
vocab, unknown = (list(args) + 2 * [None])[:2]
|
||||
if "vocab" in kwargs:
|
||||
vocab = kwargs.get("vocab")
|
||||
if "unknown" in kwargs:
|
||||
unknown = kwargs.get("unknown")
|
||||
[vocab, unknown], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if unknown is not None:
|
||||
if not (isinstance(unknown, int) and unknown >= 0):
|
||||
raise ValueError("unknown needs to be a non-negative integer.")
|
||||
type_check(unknown, (int,), "unknown")
|
||||
check_positive(unknown)
|
||||
type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
|
||||
|
||||
if not isinstance(vocab, cde.Vocab):
|
||||
raise ValueError("vocab is not an instance of cde.Vocab.")
|
||||
|
||||
kwargs["vocab"] = vocab
|
||||
kwargs["unknown"] = unknown
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -69,50 +61,15 @@ def check_from_file(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
|
||||
if "file_path" in kwargs:
|
||||
file_path = kwargs.get("file_path")
|
||||
if "delimiter" in kwargs:
|
||||
delimiter = kwargs.get("delimiter")
|
||||
if "vocab_size" in kwargs:
|
||||
vocab_size = kwargs.get("vocab_size")
|
||||
if "special_tokens" in kwargs:
|
||||
special_tokens = kwargs.get("special_tokens")
|
||||
if "special_first" in kwargs:
|
||||
special_first = kwargs.get("special_first")
|
||||
|
||||
if not isinstance(file_path, str):
|
||||
raise ValueError("file_path needs to be str.")
|
||||
|
||||
if delimiter is not None:
|
||||
if not isinstance(delimiter, str):
|
||||
raise ValueError("delimiter needs to be str.")
|
||||
else:
|
||||
delimiter = ""
|
||||
if vocab_size is not None:
|
||||
if not (isinstance(vocab_size, int) and vocab_size > 0):
|
||||
raise ValueError("vocab size needs to be a positive integer.")
|
||||
else:
|
||||
vocab_size = -1
|
||||
|
||||
if special_first is None:
|
||||
special_first = True
|
||||
|
||||
if not isinstance(special_first, bool):
|
||||
raise ValueError("special_first needs to be a boolean value")
|
||||
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
|
||||
[file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
|
||||
**kwargs)
|
||||
check_unique_list_of_words(special_tokens, "special_tokens")
|
||||
type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
|
||||
if vocab_size is not None:
|
||||
check_value(vocab_size, (-1, INT32_MAX), "vocab_size")
|
||||
type_check(special_first, (bool,), special_first)
|
||||
|
||||
kwargs["file_path"] = file_path
|
||||
kwargs["delimiter"] = delimiter
|
||||
kwargs["vocab_size"] = vocab_size
|
||||
kwargs["special_tokens"] = special_tokens
|
||||
kwargs["special_first"] = special_first
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -122,16 +79,10 @@ def check_from_list(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
|
||||
if "word_list" in kwargs:
|
||||
word_list = kwargs.get("word_list")
|
||||
if "special_tokens" in kwargs:
|
||||
special_tokens = kwargs.get("special_tokens")
|
||||
if "special_first" in kwargs:
|
||||
special_first = kwargs.get("special_first")
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
[word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
word_set = check_unique_list_of_words(word_list, "word_list")
|
||||
if special_tokens is not None:
|
||||
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
|
||||
|
||||
intersect = word_set.intersection(token_set)
|
||||
|
@ -139,16 +90,9 @@ def check_from_list(method):
|
|||
if intersect != set():
|
||||
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
|
||||
|
||||
if special_first is None:
|
||||
special_first = True
|
||||
type_check(special_first, (bool,), "special_first")
|
||||
|
||||
if not isinstance(special_first, bool):
|
||||
raise ValueError("special_first needs to be a boolean value.")
|
||||
|
||||
kwargs["word_list"] = word_list
|
||||
kwargs["special_tokens"] = special_tokens
|
||||
kwargs["special_first"] = special_first
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -158,18 +102,15 @@ def check_from_dict(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
word_dict, = (list(args) + [None])[:1]
|
||||
if "word_dict" in kwargs:
|
||||
word_dict = kwargs.get("word_dict")
|
||||
if not isinstance(word_dict, dict):
|
||||
raise ValueError("word_dict needs to be a list of word,id pairs.")
|
||||
[word_dict], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(word_dict, (dict,), "word_dict")
|
||||
|
||||
for word, word_id in word_dict.items():
|
||||
if not isinstance(word, str):
|
||||
raise ValueError("Each word in word_dict needs to be type string.")
|
||||
if not (isinstance(word_id, int) and word_id >= 0):
|
||||
raise ValueError("Each word id needs to be positive integer.")
|
||||
kwargs["word_dict"] = word_dict
|
||||
return method(self, **kwargs)
|
||||
type_check(word, (str,), "word")
|
||||
type_check(word_id, (int,), "word_id")
|
||||
check_value(word_id, (-1, INT32_MAX), "word_id")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -179,23 +120,8 @@ def check_jieba_init(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
|
||||
|
||||
if "hmm_path" in kwargs:
|
||||
hmm_path = kwargs.get("hmm_path")
|
||||
if "mp_path" in kwargs:
|
||||
mp_path = kwargs.get("mp_path")
|
||||
if hmm_path is None:
|
||||
raise ValueError(
|
||||
"The dict of HMMSegment in cppjieba is not provided.")
|
||||
kwargs["hmm_path"] = hmm_path
|
||||
if mp_path is None:
|
||||
raise ValueError(
|
||||
"The dict of MPSegment in cppjieba is not provided.")
|
||||
kwargs["mp_path"] = mp_path
|
||||
if model is not None:
|
||||
kwargs["model"] = model
|
||||
return method(self, **kwargs)
|
||||
parse_user_args(method, *args, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -205,19 +131,12 @@ def check_jieba_add_word(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
word, freq = (list(args) + 2 * [None])[:2]
|
||||
|
||||
if "word" in kwargs:
|
||||
word = kwargs.get("word")
|
||||
if "freq" in kwargs:
|
||||
freq = kwargs.get("freq")
|
||||
[word, freq], _ = parse_user_args(method, *args, **kwargs)
|
||||
if word is None:
|
||||
raise ValueError("word is not provided.")
|
||||
kwargs["word"] = word
|
||||
if freq is not None:
|
||||
check_uint32(freq)
|
||||
kwargs["freq"] = freq
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -227,13 +146,8 @@ def check_jieba_add_dict(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
user_dict = (list(args) + [None])[0]
|
||||
if "user_dict" in kwargs:
|
||||
user_dict = kwargs.get("user_dict")
|
||||
if user_dict is None:
|
||||
raise ValueError("user_dict is not provided.")
|
||||
kwargs["user_dict"] = user_dict
|
||||
return method(self, **kwargs)
|
||||
parse_user_args(method, *args, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -244,69 +158,39 @@ def check_from_dataset(method):
|
|||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
||||
dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
|
||||
if "dataset" in kwargs:
|
||||
dataset = kwargs.get("dataset")
|
||||
if "columns" in kwargs:
|
||||
columns = kwargs.get("columns")
|
||||
if "freq_range" in kwargs:
|
||||
freq_range = kwargs.get("freq_range")
|
||||
if "top_k" in kwargs:
|
||||
top_k = kwargs.get("top_k")
|
||||
if "special_tokens" in kwargs:
|
||||
special_tokens = kwargs.get("special_tokens")
|
||||
if "special_first" in kwargs:
|
||||
special_first = kwargs.get("special_first")
|
||||
|
||||
if columns is None:
|
||||
columns = []
|
||||
|
||||
[_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
|
||||
**kwargs)
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
col_names = ["col_{0}".format(i) for i in range(len(columns))]
|
||||
type_check_list(columns, (str,), col_names)
|
||||
|
||||
for column in columns:
|
||||
if not isinstance(column, str):
|
||||
raise ValueError("columns need to be a list of strings.")
|
||||
if freq_range is not None:
|
||||
type_check(freq_range, (tuple,), "freq_range")
|
||||
|
||||
if freq_range is None:
|
||||
freq_range = (None, None)
|
||||
|
||||
if not isinstance(freq_range, tuple) or len(freq_range) != 2:
|
||||
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
|
||||
if len(freq_range) != 2:
|
||||
raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.")
|
||||
|
||||
for num in freq_range:
|
||||
if num is not None and (not isinstance(num, int)):
|
||||
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
|
||||
raise ValueError(
|
||||
"freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
|
||||
|
||||
if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
|
||||
if freq_range[0] > freq_range[1] or freq_range[0] < 0:
|
||||
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
|
||||
|
||||
if top_k is not None and (not isinstance(top_k, int)):
|
||||
raise ValueError("top_k needs to be a positive integer.")
|
||||
type_check(top_k, (int, type(None)), "top_k")
|
||||
|
||||
if isinstance(top_k, int) and top_k <= 0:
|
||||
raise ValueError("top_k needs to be a positive integer.")
|
||||
|
||||
if special_first is None:
|
||||
special_first = True
|
||||
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
|
||||
if not isinstance(special_first, bool):
|
||||
raise ValueError("special_first needs to be a boolean value.")
|
||||
if isinstance(top_k, int):
|
||||
check_value(top_k, (0, INT32_MAX), "top_k")
|
||||
type_check(special_first, (bool,), "special_first")
|
||||
|
||||
if special_tokens is not None:
|
||||
check_unique_list_of_words(special_tokens, "special_tokens")
|
||||
|
||||
kwargs["dataset"] = dataset
|
||||
kwargs["columns"] = columns
|
||||
kwargs["freq_range"] = freq_range
|
||||
kwargs["top_k"] = top_k
|
||||
kwargs["special_tokens"] = special_tokens
|
||||
kwargs["special_first"] = special_first
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -316,15 +200,7 @@ def check_ngram(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4]
|
||||
if "n" in kwargs:
|
||||
n = kwargs.get("n")
|
||||
if "left_pad" in kwargs:
|
||||
left_pad = kwargs.get("left_pad")
|
||||
if "right_pad" in kwargs:
|
||||
right_pad = kwargs.get("right_pad")
|
||||
if "separator" in kwargs:
|
||||
separator = kwargs.get("separator")
|
||||
[n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if isinstance(n, int):
|
||||
n = [n]
|
||||
|
@ -332,15 +208,9 @@ def check_ngram(method):
|
|||
if not (isinstance(n, list) and n != []):
|
||||
raise ValueError("n needs to be a non-empty list of positive integers.")
|
||||
|
||||
for gram in n:
|
||||
if not (isinstance(gram, int) and gram > 0):
|
||||
raise ValueError("n in ngram needs to be a positive number.")
|
||||
|
||||
if left_pad is None:
|
||||
left_pad = ("", 0)
|
||||
|
||||
if right_pad is None:
|
||||
right_pad = ("", 0)
|
||||
for i, gram in enumerate(n):
|
||||
type_check(gram, (int,), "gram[{0}]".format(i))
|
||||
check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
|
||||
|
||||
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
|
||||
left_pad[1], int)):
|
||||
|
@ -353,11 +223,7 @@ def check_ngram(method):
|
|||
if not (left_pad[1] >= 0 and right_pad[1] >= 0):
|
||||
raise ValueError("padding width need to be positive numbers.")
|
||||
|
||||
if separator is None:
|
||||
separator = " "
|
||||
|
||||
if not isinstance(separator, str):
|
||||
raise ValueError("separator needs to be a string.")
|
||||
type_check(separator, (str,), "separator")
|
||||
|
||||
kwargs["n"] = n
|
||||
kwargs["left_pad"] = left_pad
|
||||
|
@ -374,16 +240,8 @@ def check_pair_truncate(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
max_length = (list(args) + [None])[0]
|
||||
if "max_length" in kwargs:
|
||||
max_length = kwargs.get("max_length")
|
||||
if max_length is None:
|
||||
raise ValueError("max_length is not provided.")
|
||||
|
||||
check_pos_int64(max_length)
|
||||
kwargs["max_length"] = max_length
|
||||
|
||||
return method(self, **kwargs)
|
||||
parse_user_args(method, *args, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -393,22 +251,13 @@ def check_to_number(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
data_type = (list(args) + [None])[0]
|
||||
if "data_type" in kwargs:
|
||||
data_type = kwargs.get("data_type")
|
||||
|
||||
if data_type is None:
|
||||
raise ValueError("data_type is a mandatory parameter but was not provided.")
|
||||
|
||||
if not isinstance(data_type, typing.Type):
|
||||
raise TypeError("data_type is not a MindSpore data type.")
|
||||
[data_type], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(data_type, (typing.Type,), "data_type")
|
||||
|
||||
if data_type not in mstype.number_type:
|
||||
raise TypeError("data_type is not numeric data type.")
|
||||
|
||||
kwargs["data_type"] = data_type
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -418,18 +267,11 @@ def check_python_tokenizer(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
tokenizer = (list(args) + [None])[0]
|
||||
if "tokenizer" in kwargs:
|
||||
tokenizer = kwargs.get("tokenizer")
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError("tokenizer is a mandatory parameter.")
|
||||
[tokenizer], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if not callable(tokenizer):
|
||||
raise TypeError("tokenizer is not a callable python function")
|
||||
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -18,6 +18,7 @@ from functools import wraps
|
|||
import numpy as np
|
||||
|
||||
from mindspore._c_expression import typing
|
||||
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive
|
||||
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
POS_INT_MIN = 1
|
||||
|
@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992
|
|||
DOUBLE_MIN_INTEGER = -9007199254740992
|
||||
|
||||
|
||||
def check_type(value, valid_type):
|
||||
if not isinstance(value, valid_type):
|
||||
raise ValueError("Wrong input type")
|
||||
|
||||
|
||||
def check_value(value, valid_range):
|
||||
if value < valid_range[0] or value > valid_range[1]:
|
||||
raise ValueError("Input is not within the required range")
|
||||
|
||||
|
||||
def check_range(values, valid_range):
|
||||
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
|
||||
raise ValueError("Input range is not valid")
|
||||
|
||||
|
||||
def check_positive(value):
|
||||
if value <= 0:
|
||||
raise ValueError("Input must greater than 0")
|
||||
|
||||
|
||||
def check_positive_float(value, valid_max=None):
|
||||
if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max):
|
||||
raise ValueError("Input need to be a valid positive float.")
|
||||
|
||||
|
||||
def check_bool(value):
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError("Value needs to be a boolean.")
|
||||
|
||||
|
||||
def check_2tuple(value):
|
||||
if not (isinstance(value, tuple) and len(value) == 2):
|
||||
raise ValueError("Value needs to be a 2-tuple.")
|
||||
|
||||
|
||||
def check_list(value):
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("The input needs to be a list.")
|
||||
|
||||
|
||||
def check_uint8(value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("The input needs to be a integer")
|
||||
check_value(value, [UINT8_MIN, UINT8_MAX])
|
||||
|
||||
|
||||
def check_uint32(value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("The input needs to be a integer")
|
||||
check_value(value, [UINT32_MIN, UINT32_MAX])
|
||||
|
||||
|
||||
def check_pos_int32(value):
|
||||
"""Checks for int values starting from 1"""
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("The input needs to be a integer")
|
||||
check_value(value, [POS_INT_MIN, INT32_MAX])
|
||||
|
||||
|
||||
def check_uint64(value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("The input needs to be a integer")
|
||||
check_value(value, [UINT64_MIN, UINT64_MAX])
|
||||
|
||||
|
||||
def check_pos_int64(value):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("The input needs to be a integer")
|
||||
check_value(value, [UINT64_MIN, INT64_MAX])
|
||||
|
||||
|
||||
def check_pos_float32(value):
|
||||
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER])
|
||||
|
||||
|
||||
def check_pos_float64(value):
|
||||
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER])
|
||||
|
||||
|
||||
def check_one_hot_op(method):
|
||||
"""Wrapper method to check the parameters of one hot op."""
|
||||
def check_fill_value(method):
|
||||
"""Wrapper method to check the parameters of fill_value."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
num_classes, smoothing_rate = args
|
||||
if "num_classes" in kwargs:
|
||||
num_classes = kwargs.get("num_classes")
|
||||
if "smoothing_rate" in kwargs:
|
||||
smoothing_rate = kwargs.get("smoothing_rate")
|
||||
[fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_one_hot_op(method):
|
||||
"""Wrapper method to check the parameters of one_hot_op."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(num_classes, (int,), "num_classes")
|
||||
check_positive(num_classes)
|
||||
|
||||
if num_classes is None:
|
||||
raise ValueError("num_classes")
|
||||
check_pos_int32(num_classes)
|
||||
kwargs["num_classes"] = num_classes
|
||||
if smoothing_rate is not None:
|
||||
check_value(smoothing_rate, [0., 1.])
|
||||
kwargs["smoothing_rate"] = smoothing_rate
|
||||
check_value(smoothing_rate, [0., 1.], "smoothing_rate")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -146,35 +74,12 @@ def check_num_classes(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
num_classes = (list(args) + [None])[0]
|
||||
if "num_classes" in kwargs:
|
||||
num_classes = kwargs.get("num_classes")
|
||||
if num_classes is None:
|
||||
raise ValueError("num_classes is not provided.")
|
||||
[num_classes], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
check_pos_int32(num_classes)
|
||||
kwargs["num_classes"] = num_classes
|
||||
type_check(num_classes, (int,), "num_classes")
|
||||
check_positive(num_classes)
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_fill_value(method):
|
||||
"""Wrapper method to check the parameters of fill value."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
fill_value = (list(args) + [None])[0]
|
||||
if "fill_value" in kwargs:
|
||||
fill_value = kwargs.get("fill_value")
|
||||
if fill_value is None:
|
||||
raise ValueError("fill_value is not provided.")
|
||||
if not isinstance(fill_value, (str, float, bool, int, bytes)):
|
||||
raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
|
||||
kwargs["fill_value"] = fill_value
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -184,17 +89,11 @@ def check_de_type(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
data_type = (list(args) + [None])[0]
|
||||
if "data_type" in kwargs:
|
||||
data_type = kwargs.get("data_type")
|
||||
[data_type], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if data_type is None:
|
||||
raise ValueError("data_type is not provided.")
|
||||
if not isinstance(data_type, typing.Type):
|
||||
raise TypeError("data_type is not a MindSpore data type.")
|
||||
kwargs["data_type"] = data_type
|
||||
type_check(data_type, (typing.Type,), "data_type")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -204,13 +103,11 @@ def check_slice_op(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args):
|
||||
for i, arg in enumerate(args):
|
||||
if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)):
|
||||
raise TypeError("Indexing of dim " + str(i) + "is not of valid type")
|
||||
for _, arg in enumerate(args):
|
||||
type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg")
|
||||
if isinstance(arg, list):
|
||||
for a in arg:
|
||||
if not isinstance(a, int):
|
||||
raise TypeError("Index " + a + " is not an int")
|
||||
type_check(a, (int,), "a")
|
||||
return method(self, *args)
|
||||
|
||||
return new_method
|
||||
|
@ -221,36 +118,14 @@ def check_mask_op(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
operator, constant, dtype = (list(args) + 3 * [None])[:3]
|
||||
if "operator" in kwargs:
|
||||
operator = kwargs.get("operator")
|
||||
if "constant" in kwargs:
|
||||
constant = kwargs.get("constant")
|
||||
if "dtype" in kwargs:
|
||||
dtype = kwargs.get("dtype")
|
||||
|
||||
if operator is None:
|
||||
raise ValueError("operator is not provided.")
|
||||
|
||||
if constant is None:
|
||||
raise ValueError("constant is not provided.")
|
||||
[operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
from .c_transforms import Relational
|
||||
if not isinstance(operator, Relational):
|
||||
raise TypeError("operator is not a Relational operator enum.")
|
||||
type_check(operator, (Relational,), "operator")
|
||||
type_check(constant, (str, float, bool, int, bytes), "constant")
|
||||
type_check(dtype, (typing.Type,), "dtype")
|
||||
|
||||
if not isinstance(constant, (str, float, bool, int, bytes)):
|
||||
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
|
||||
|
||||
if dtype is not None:
|
||||
if not isinstance(dtype, typing.Type):
|
||||
raise TypeError("dtype is not a MindSpore data type.")
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
kwargs["operator"] = operator
|
||||
kwargs["constant"] = constant
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -260,22 +135,12 @@ def check_pad_end(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
pad_shape, pad_value = (list(args) + 2 * [None])[:2]
|
||||
if "pad_shape" in kwargs:
|
||||
pad_shape = kwargs.get("pad_shape")
|
||||
if "pad_value" in kwargs:
|
||||
pad_value = kwargs.get("pad_value")
|
||||
|
||||
if pad_shape is None:
|
||||
raise ValueError("pad_shape is not provided.")
|
||||
[pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if pad_value is not None:
|
||||
if not isinstance(pad_value, (str, float, bool, int, bytes)):
|
||||
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
|
||||
kwargs["pad_value"] = pad_value
|
||||
|
||||
if not isinstance(pad_shape, list):
|
||||
raise TypeError("pad_shape must be a list")
|
||||
type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
|
||||
type_check(pad_shape, (list,), "pad_end")
|
||||
|
||||
for dim in pad_shape:
|
||||
if dim is not None:
|
||||
|
@ -284,9 +149,7 @@ def check_pad_end(method):
|
|||
else:
|
||||
raise TypeError("a value in the list is not an integer.")
|
||||
|
||||
kwargs["pad_shape"] = pad_shape
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -296,31 +159,24 @@ def check_concat_type(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
axis, prepend, append = (list(args) + 3 * [None])[:3]
|
||||
if "prepend" in kwargs:
|
||||
prepend = kwargs.get("prepend")
|
||||
if "append" in kwargs:
|
||||
append = kwargs.get("append")
|
||||
if "axis" in kwargs:
|
||||
axis = kwargs.get("axis")
|
||||
|
||||
[axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if axis is not None:
|
||||
if not isinstance(axis, int):
|
||||
raise TypeError("axis type is not valid, must be an integer.")
|
||||
type_check(axis, (int,), "axis")
|
||||
if axis not in (0, -1):
|
||||
raise ValueError("only 1D concatenation supported.")
|
||||
kwargs["axis"] = axis
|
||||
|
||||
if prepend is not None:
|
||||
if not isinstance(prepend, (type(None), np.ndarray)):
|
||||
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
|
||||
kwargs["prepend"] = prepend
|
||||
type_check(prepend, (np.ndarray,), "prepend")
|
||||
if len(prepend.shape) != 1:
|
||||
raise ValueError("can only prepend 1D arrays.")
|
||||
|
||||
if append is not None:
|
||||
if not isinstance(append, (type(None), np.ndarray)):
|
||||
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
|
||||
kwargs["append"] = append
|
||||
type_check(append, (np.ndarray,), "append")
|
||||
if len(append.shape) != 1:
|
||||
raise ValueError("can only append 1D arrays.")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -40,12 +40,14 @@ Examples:
|
|||
>>> dataset = dataset.map(input_columns="image", operations=transforms_list)
|
||||
>>> dataset = dataset.map(input_columns="label", operations=onehot_op)
|
||||
"""
|
||||
import numbers
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .utils import Inter, Border
|
||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp
|
||||
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
|
||||
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
|
||||
FLOAT_MAX_INTEGER
|
||||
|
||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||
|
@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
|
|||
Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC}
|
||||
|
||||
|
||||
def parse_padding(padding):
|
||||
if isinstance(padding, numbers.Number):
|
||||
padding = [padding] * 4
|
||||
if len(padding) == 2:
|
||||
left = right = padding[0]
|
||||
top = bottom = padding[1]
|
||||
padding = (left, top, right, bottom,)
|
||||
if isinstance(padding, list):
|
||||
padding = tuple(padding)
|
||||
return padding
|
||||
|
||||
|
||||
class Decode(cde.DecodeOp):
|
||||
"""
|
||||
Decode the input image in RGB mode.
|
||||
|
@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp):
|
|||
|
||||
@check_random_crop
|
||||
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
if padding is None:
|
||||
padding = (0, 0, 0, 0)
|
||||
else:
|
||||
padding = parse_padding(padding)
|
||||
if isinstance(fill_value, int): # temporary fix
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
border_type = DE_C_BORDER_TYPE[padding_mode]
|
||||
|
||||
self.size = size
|
||||
self.padding = padding
|
||||
self.pad_if_needed = pad_if_needed
|
||||
self.fill_value = fill_value
|
||||
self.padding_mode = padding_mode.value
|
||||
if padding is None:
|
||||
padding = (0, 0, 0, 0)
|
||||
if isinstance(fill_value, int): # temporary fix
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
border_type = DE_C_BORDER_TYPE[padding_mode]
|
||||
|
||||
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
|
||||
|
||||
|
||||
|
@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp):
|
|||
|
||||
@check_random_crop
|
||||
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
if padding is None:
|
||||
padding = (0, 0, 0, 0)
|
||||
else:
|
||||
padding = parse_padding(padding)
|
||||
|
||||
if isinstance(fill_value, int): # temporary fix
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
border_type = DE_C_BORDER_TYPE[padding_mode]
|
||||
|
||||
self.size = size
|
||||
self.padding = padding
|
||||
self.pad_if_needed = pad_if_needed
|
||||
self.fill_value = fill_value
|
||||
self.padding_mode = padding_mode.value
|
||||
if padding is None:
|
||||
padding = (0, 0, 0, 0)
|
||||
if isinstance(fill_value, int): # temporary fix
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
border_type = DE_C_BORDER_TYPE[padding_mode]
|
||||
|
||||
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
|
||||
|
||||
|
||||
|
@ -292,6 +319,8 @@ class Resize(cde.ResizeOp):
|
|||
|
||||
@check_resize_interpolation
|
||||
def __init__(self, size, interpolation=Inter.LINEAR):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
interpoltn = DE_C_INTER_MODE[interpolation]
|
||||
|
@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
|
|||
@check_random_resize_crop
|
||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||
interpolation=Inter.BILINEAR, max_attempts=10):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
|
@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp):
|
|||
@check_random_resize_crop
|
||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||
interpolation=Inter.BILINEAR, max_attempts=10):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
|
@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp):
|
|||
|
||||
@check_crop
|
||||
def __init__(self, size):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
super().__init__(*size)
|
||||
|
||||
|
@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp):
|
|||
|
||||
@check_random_color_adjust
|
||||
def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)):
|
||||
brightness = self.expand_values(brightness)
|
||||
contrast = self.expand_values(contrast)
|
||||
saturation = self.expand_values(saturation)
|
||||
hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False)
|
||||
|
||||
self.brightness = brightness
|
||||
self.contrast = contrast
|
||||
self.saturation = saturation
|
||||
self.hue = hue
|
||||
|
||||
super().__init__(*brightness, *contrast, *saturation, *hue)
|
||||
|
||||
def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
|
||||
if isinstance(value, numbers.Number):
|
||||
value = [center - value, center + value]
|
||||
if non_negative:
|
||||
value[0] = max(0, value[0])
|
||||
check_range(value, bound)
|
||||
return (value[0], value[1])
|
||||
|
||||
|
||||
class RandomRotation(cde.RandomRotationOp):
|
||||
"""
|
||||
|
@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp):
|
|||
self.expand = expand
|
||||
self.center = center
|
||||
self.fill_value = fill_value
|
||||
if isinstance(degrees, numbers.Number):
|
||||
degrees = (-degrees, degrees)
|
||||
if center is None:
|
||||
center = (-1, -1)
|
||||
if isinstance(fill_value, int): # temporary fix
|
||||
|
@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp):
|
|||
@check_random_resize_crop
|
||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||
interpolation=Inter.BILINEAR, max_attempts=10):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
|
@ -623,12 +676,14 @@ class Pad(cde.PadOp):
|
|||
|
||||
@check_pad
|
||||
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
|
||||
self.padding = padding
|
||||
self.fill_value = fill_value
|
||||
self.padding_mode = padding_mode
|
||||
padding = parse_padding(padding)
|
||||
if isinstance(fill_value, int): # temporary fix
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
padding_mode = DE_C_BORDER_TYPE[padding_mode]
|
||||
|
||||
self.padding = padding
|
||||
self.fill_value = fill_value
|
||||
self.padding_mode = padding_mode
|
||||
super().__init__(*padding, padding_mode, *fill_value)
|
||||
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from . import py_transforms_util as util
|
||||
from .c_transforms import parse_padding
|
||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||
check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||
check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \
|
||||
|
@ -295,6 +296,10 @@ class RandomCrop:
|
|||
|
||||
@check_random_crop
|
||||
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
||||
if padding is None:
|
||||
padding = (0, 0, 0, 0)
|
||||
else:
|
||||
padding = parse_padding(padding)
|
||||
self.size = size
|
||||
self.padding = padding
|
||||
self.pad_if_needed = pad_if_needed
|
||||
|
@ -753,6 +758,8 @@ class TenCrop:
|
|||
|
||||
@check_ten_crop
|
||||
def __init__(self, size, use_vertical_flip=False):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
self.use_vertical_flip = use_vertical_flip
|
||||
|
||||
|
@ -877,6 +884,8 @@ class Pad:
|
|||
|
||||
@check_pad
|
||||
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
|
||||
parse_padding(padding)
|
||||
|
||||
self.padding = padding
|
||||
self.fill_value = fill_value
|
||||
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
|
||||
|
@ -1129,56 +1138,23 @@ class RandomAffine:
|
|||
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0):
|
||||
# Parameter checking
|
||||
# rotation
|
||||
if isinstance(degrees, numbers.Number):
|
||||
if degrees < 0:
|
||||
raise ValueError("If degrees is a single number, it must be positive.")
|
||||
self.degrees = (-degrees, degrees)
|
||||
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
|
||||
self.degrees = degrees
|
||||
else:
|
||||
raise TypeError("If degrees is a list or tuple, it must be of length 2.")
|
||||
|
||||
# translation
|
||||
if translate is not None:
|
||||
if isinstance(translate, (tuple, list)) and len(translate) == 2:
|
||||
for t in translate:
|
||||
if t < 0.0 or t > 1.0:
|
||||
raise ValueError("translation values should be between 0 and 1")
|
||||
else:
|
||||
raise TypeError("translate should be a list or tuple of length 2.")
|
||||
self.translate = translate
|
||||
|
||||
# scale
|
||||
if scale is not None:
|
||||
if isinstance(scale, (tuple, list)) and len(scale) == 2:
|
||||
for s in scale:
|
||||
if s <= 0:
|
||||
raise ValueError("scale values should be positive")
|
||||
else:
|
||||
raise TypeError("scale should be a list or tuple of length 2.")
|
||||
self.scale_ranges = scale
|
||||
|
||||
# shear
|
||||
if shear is not None:
|
||||
if isinstance(shear, numbers.Number):
|
||||
if shear < 0:
|
||||
raise ValueError("If shear is a single number, it must be positive.")
|
||||
self.shear = (-1 * shear, shear)
|
||||
elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4):
|
||||
# X-Axis shear with [min, max]
|
||||
shear = (-1 * shear, shear)
|
||||
else:
|
||||
if len(shear) == 2:
|
||||
self.shear = [shear[0], shear[1], 0., 0.]
|
||||
shear = [shear[0], shear[1], 0., 0.]
|
||||
elif len(shear) == 4:
|
||||
self.shear = [s for s in shear]
|
||||
else:
|
||||
raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.")
|
||||
else:
|
||||
shear = [s for s in shear]
|
||||
|
||||
if isinstance(degrees, numbers.Number):
|
||||
degrees = (-degrees, degrees)
|
||||
|
||||
self.degrees = degrees
|
||||
self.translate = translate
|
||||
self.scale_ranges = scale
|
||||
self.shear = shear
|
||||
|
||||
# resample
|
||||
self.resample = DE_PY_INTER_MODE[resample]
|
||||
|
||||
# fill_value
|
||||
self.fill_value = fill_value
|
||||
|
||||
def __call__(self, img):
|
||||
|
|
|
@ -16,47 +16,35 @@
|
|||
"""
|
||||
import numbers
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
from mindspore._c_dataengine import TensorOp
|
||||
|
||||
from .utils import Inter, Border
|
||||
from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \
|
||||
check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX
|
||||
|
||||
|
||||
def check_inter_mode(mode):
|
||||
if not isinstance(mode, Inter):
|
||||
raise ValueError("Invalid interpolation mode.")
|
||||
|
||||
|
||||
def check_border_type(mode):
|
||||
if not isinstance(mode, Border):
|
||||
raise ValueError("Invalid padding mode.")
|
||||
from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
||||
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list
|
||||
|
||||
|
||||
def check_crop_size(size):
|
||||
"""Wrapper method to check the parameters of crop size."""
|
||||
type_check(size, (int, list, tuple), "size")
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
check_value(size, (1, FLOAT_MAX_INTEGER))
|
||||
elif isinstance(size, (tuple, list)) and len(size) == 2:
|
||||
size = size
|
||||
for value in size:
|
||||
check_value(value, (1, FLOAT_MAX_INTEGER))
|
||||
else:
|
||||
raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
|
||||
for value in size:
|
||||
check_pos_int32(value)
|
||||
return size
|
||||
|
||||
|
||||
def check_resize_size(size):
|
||||
"""Wrapper method to check the parameters of resize."""
|
||||
if isinstance(size, int):
|
||||
check_pos_int32(size)
|
||||
check_value(size, (1, FLOAT_MAX_INTEGER))
|
||||
elif isinstance(size, (tuple, list)) and len(size) == 2:
|
||||
for value in size:
|
||||
check_value(value, (1, INT32_MAX))
|
||||
for i, value in enumerate(size):
|
||||
check_value(value, (1, INT32_MAX), "size at dim {0}".format(i))
|
||||
else:
|
||||
raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
|
||||
return size
|
||||
|
||||
|
||||
def check_normalize_c_param(mean, std):
|
||||
|
@ -72,9 +60,9 @@ def check_normalize_py_param(mean, std):
|
|||
if len(mean) != len(std):
|
||||
raise ValueError("Length of mean and std must be equal")
|
||||
for mean_value in mean:
|
||||
check_value(mean_value, [0., 1.])
|
||||
check_value(mean_value, [0., 1.], "mean_value")
|
||||
for std_value in std:
|
||||
check_value(std_value, [0., 1.])
|
||||
check_value(std_value, [0., 1.], "std_value")
|
||||
|
||||
|
||||
def check_fill_value(fill_value):
|
||||
|
@ -85,66 +73,37 @@ def check_fill_value(fill_value):
|
|||
check_uint8(value)
|
||||
else:
|
||||
raise TypeError("fill_value should be a single integer or a 3-tuple.")
|
||||
return fill_value
|
||||
|
||||
|
||||
def check_padding(padding):
|
||||
"""Parsing the padding arguments and check if it is legal."""
|
||||
if isinstance(padding, numbers.Number):
|
||||
top = bottom = left = right = padding
|
||||
|
||||
elif isinstance(padding, (tuple, list)):
|
||||
if len(padding) == 2:
|
||||
left = right = padding[0]
|
||||
top = bottom = padding[1]
|
||||
elif len(padding) == 4:
|
||||
left = padding[0]
|
||||
top = padding[1]
|
||||
right = padding[2]
|
||||
bottom = padding[3]
|
||||
else:
|
||||
type_check(padding, (tuple, list, numbers.Number), "padding")
|
||||
if isinstance(padding, (tuple, list)):
|
||||
if len(padding) not in (2, 4):
|
||||
raise ValueError("The size of the padding list or tuple should be 2 or 4.")
|
||||
else:
|
||||
raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.")
|
||||
if not (isinstance(left, int) and isinstance(top, int) and isinstance(right, int) and isinstance(bottom, int)):
|
||||
raise TypeError("Padding value should be integer.")
|
||||
if left < 0 or top < 0 or right < 0 or bottom < 0:
|
||||
raise ValueError("Padding value could not be negative.")
|
||||
return left, top, right, bottom
|
||||
for i, pad_value in enumerate(padding):
|
||||
type_check(pad_value, (int,), "padding[{}]".format(i))
|
||||
check_value(pad_value, (0, INT32_MAX), "pad_value")
|
||||
|
||||
|
||||
def check_degrees(degrees):
|
||||
"""Check if the degrees is legal."""
|
||||
type_check(degrees, (numbers.Number, list, tuple), "degrees")
|
||||
if isinstance(degrees, numbers.Number):
|
||||
if degrees < 0:
|
||||
raise ValueError("If degrees is a single number, it cannot be negative.")
|
||||
degrees = (-degrees, degrees)
|
||||
check_value(degrees, (0, float("inf")), "degrees")
|
||||
elif isinstance(degrees, (list, tuple)):
|
||||
if len(degrees) != 2:
|
||||
raise TypeError("If degrees is a sequence, the length must be 2.")
|
||||
else:
|
||||
raise TypeError("Degrees must be a single non-negative number or a sequence")
|
||||
return degrees
|
||||
|
||||
|
||||
def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
|
||||
"""Check the parameters in random color adjust operation."""
|
||||
type_check(value, (numbers.Number, list, tuple), input_name)
|
||||
if isinstance(value, numbers.Number):
|
||||
if value < 0:
|
||||
raise ValueError("The input value of {} cannot be negative.".format(input_name))
|
||||
# convert value into a range
|
||||
value = [center - value, center + value]
|
||||
if non_negative:
|
||||
value[0] = max(0, value[0])
|
||||
elif isinstance(value, (list, tuple)) and len(value) == 2:
|
||||
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
||||
raise ValueError("Please check your value range of {} is valid and "
|
||||
"within the bound {}".format(input_name, bound))
|
||||
else:
|
||||
raise TypeError("Input of {} should be either a single value, or a list/tuple of "
|
||||
"length 2.".format(input_name))
|
||||
factor = (value[0], value[1])
|
||||
return factor
|
||||
check_range(value, bound)
|
||||
|
||||
|
||||
def check_erasing_value(value):
|
||||
|
@ -159,15 +118,10 @@ def check_crop(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
size = (list(args) + [None])[0]
|
||||
if "size" in kwargs:
|
||||
size = kwargs.get("size")
|
||||
if size is None:
|
||||
raise ValueError("size is not provided.")
|
||||
size = check_crop_size(size)
|
||||
kwargs["size"] = size
|
||||
[size], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_crop_size(size)
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -177,23 +131,12 @@ def check_resize_interpolation(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
size, interpolation = args
|
||||
if "size" in kwargs:
|
||||
size = kwargs.get("size")
|
||||
if "interpolation" in kwargs:
|
||||
interpolation = kwargs.get("interpolation")
|
||||
|
||||
if size is None:
|
||||
raise ValueError("size is not provided.")
|
||||
size = check_resize_size(size)
|
||||
kwargs["size"] = size
|
||||
|
||||
[size, interpolation], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_resize_size(size)
|
||||
if interpolation is not None:
|
||||
check_inter_mode(interpolation)
|
||||
kwargs["interpolation"] = interpolation
|
||||
type_check(interpolation, (Inter,), "interpolation")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -203,16 +146,10 @@ def check_resize(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
size = (list(args) + [None])[0]
|
||||
if "size" in kwargs:
|
||||
size = kwargs.get("size")
|
||||
[size], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_resize_size(size)
|
||||
|
||||
if size is None:
|
||||
raise ValueError("size is not provided.")
|
||||
size = check_resize_size(size)
|
||||
kwargs["size"] = size
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -222,39 +159,20 @@ def check_random_resize_crop(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 5 * [None])[:5]
|
||||
size, scale, ratio, interpolation, max_attempts = args
|
||||
if "size" in kwargs:
|
||||
size = kwargs.get("size")
|
||||
if "scale" in kwargs:
|
||||
scale = kwargs.get("scale")
|
||||
if "ratio" in kwargs:
|
||||
ratio = kwargs.get("ratio")
|
||||
if "interpolation" in kwargs:
|
||||
interpolation = kwargs.get("interpolation")
|
||||
if "max_attempts" in kwargs:
|
||||
max_attempts = kwargs.get("max_attempts")
|
||||
|
||||
if size is None:
|
||||
raise ValueError("size is not provided.")
|
||||
size = check_crop_size(size)
|
||||
kwargs["size"] = size
|
||||
[size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_crop_size(size)
|
||||
|
||||
if scale is not None:
|
||||
check_range(scale, [0, FLOAT_MAX_INTEGER])
|
||||
kwargs["scale"] = scale
|
||||
if ratio is not None:
|
||||
check_range(ratio, [0, FLOAT_MAX_INTEGER])
|
||||
check_positive(ratio[0])
|
||||
kwargs["ratio"] = ratio
|
||||
check_positive(ratio[0], "ratio[0]")
|
||||
if interpolation is not None:
|
||||
check_inter_mode(interpolation)
|
||||
kwargs["interpolation"] = interpolation
|
||||
type_check(interpolation, (Inter,), "interpolation")
|
||||
if max_attempts is not None:
|
||||
check_pos_int32(max_attempts)
|
||||
kwargs["max_attempts"] = max_attempts
|
||||
check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -264,14 +182,11 @@ def check_prob(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
prob = (list(args) + [None])[0]
|
||||
if "prob" in kwargs:
|
||||
prob = kwargs.get("prob")
|
||||
if prob is not None:
|
||||
check_value(prob, [0., 1.])
|
||||
kwargs["prob"] = prob
|
||||
[prob], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(prob, (float, int,), "prob")
|
||||
check_value(prob, [0., 1.], "prob")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -281,22 +196,10 @@ def check_normalize_c(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
mean, std = args
|
||||
if "mean" in kwargs:
|
||||
mean = kwargs.get("mean")
|
||||
if "std" in kwargs:
|
||||
std = kwargs.get("std")
|
||||
|
||||
if mean is None:
|
||||
raise ValueError("mean is not provided.")
|
||||
if std is None:
|
||||
raise ValueError("std is not provided.")
|
||||
[mean, std], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_normalize_c_param(mean, std)
|
||||
kwargs["mean"] = mean
|
||||
kwargs["std"] = std
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -306,22 +209,10 @@ def check_normalize_py(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
mean, std = args
|
||||
if "mean" in kwargs:
|
||||
mean = kwargs.get("mean")
|
||||
if "std" in kwargs:
|
||||
std = kwargs.get("std")
|
||||
|
||||
if mean is None:
|
||||
raise ValueError("mean is not provided.")
|
||||
if std is None:
|
||||
raise ValueError("std is not provided.")
|
||||
[mean, std], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_normalize_py_param(mean, std)
|
||||
kwargs["mean"] = mean
|
||||
kwargs["std"] = std
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -331,38 +222,17 @@ def check_random_crop(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 5 * [None])[:5]
|
||||
size, padding, pad_if_needed, fill_value, padding_mode = args
|
||||
|
||||
if "size" in kwargs:
|
||||
size = kwargs.get("size")
|
||||
if "padding" in kwargs:
|
||||
padding = kwargs.get("padding")
|
||||
if "fill_value" in kwargs:
|
||||
fill_value = kwargs.get("fill_value")
|
||||
if "padding_mode" in kwargs:
|
||||
padding_mode = kwargs.get("padding_mode")
|
||||
if "pad_if_needed" in kwargs:
|
||||
pad_if_needed = kwargs.get("pad_if_needed")
|
||||
|
||||
if size is None:
|
||||
raise ValueError("size is not provided.")
|
||||
size = check_crop_size(size)
|
||||
kwargs["size"] = size
|
||||
|
||||
[size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_crop_size(size)
|
||||
type_check(pad_if_needed, (bool,), "pad_if_needed")
|
||||
if padding is not None:
|
||||
padding = check_padding(padding)
|
||||
kwargs["padding"] = padding
|
||||
check_padding(padding)
|
||||
if fill_value is not None:
|
||||
fill_value = check_fill_value(fill_value)
|
||||
kwargs["fill_value"] = fill_value
|
||||
check_fill_value(fill_value)
|
||||
if padding_mode is not None:
|
||||
check_border_type(padding_mode)
|
||||
kwargs["padding_mode"] = padding_mode
|
||||
if pad_if_needed is not None:
|
||||
kwargs["pad_if_needed"] = pad_if_needed
|
||||
type_check(padding_mode, (Border,), "padding_mode")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -372,27 +242,13 @@ def check_random_color_adjust(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 4 * [None])[:4]
|
||||
brightness, contrast, saturation, hue = args
|
||||
if "brightness" in kwargs:
|
||||
brightness = kwargs.get("brightness")
|
||||
if "contrast" in kwargs:
|
||||
contrast = kwargs.get("contrast")
|
||||
if "saturation" in kwargs:
|
||||
saturation = kwargs.get("saturation")
|
||||
if "hue" in kwargs:
|
||||
hue = kwargs.get("hue")
|
||||
[brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_random_color_adjust_param(brightness, "brightness")
|
||||
check_random_color_adjust_param(contrast, "contrast")
|
||||
check_random_color_adjust_param(saturation, "saturation")
|
||||
check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
|
||||
|
||||
if brightness is not None:
|
||||
kwargs["brightness"] = check_random_color_adjust_param(brightness, "brightness")
|
||||
if contrast is not None:
|
||||
kwargs["contrast"] = check_random_color_adjust_param(contrast, "contrast")
|
||||
if saturation is not None:
|
||||
kwargs["saturation"] = check_random_color_adjust_param(saturation, "saturation")
|
||||
if hue is not None:
|
||||
kwargs["hue"] = check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -402,38 +258,19 @@ def check_random_rotation(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 5 * [None])[:5]
|
||||
degrees, resample, expand, center, fill_value = args
|
||||
if "degrees" in kwargs:
|
||||
degrees = kwargs.get("degrees")
|
||||
if "resample" in kwargs:
|
||||
resample = kwargs.get("resample")
|
||||
if "expand" in kwargs:
|
||||
expand = kwargs.get("expand")
|
||||
if "center" in kwargs:
|
||||
center = kwargs.get("center")
|
||||
if "fill_value" in kwargs:
|
||||
fill_value = kwargs.get("fill_value")
|
||||
|
||||
if degrees is None:
|
||||
raise ValueError("degrees is not provided.")
|
||||
degrees = check_degrees(degrees)
|
||||
kwargs["degrees"] = degrees
|
||||
[degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_degrees(degrees)
|
||||
|
||||
if resample is not None:
|
||||
check_inter_mode(resample)
|
||||
kwargs["resample"] = resample
|
||||
type_check(resample, (Inter,), "resample")
|
||||
if expand is not None:
|
||||
check_bool(expand)
|
||||
kwargs["expand"] = expand
|
||||
type_check(expand, (bool,), "expand")
|
||||
if center is not None:
|
||||
check_2tuple(center)
|
||||
kwargs["center"] = center
|
||||
check_2tuple(center, "center")
|
||||
if fill_value is not None:
|
||||
fill_value = check_fill_value(fill_value)
|
||||
kwargs["fill_value"] = fill_value
|
||||
check_fill_value(fill_value)
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -443,16 +280,11 @@ def check_transforms_list(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
transforms = (list(args) + [None])[0]
|
||||
if "transforms" in kwargs:
|
||||
transforms = kwargs.get("transforms")
|
||||
if transforms is None:
|
||||
raise ValueError("transforms is not provided.")
|
||||
[transforms], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
check_list(transforms)
|
||||
kwargs["transforms"] = transforms
|
||||
type_check(transforms, (list,), "transforms")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -462,21 +294,14 @@ def check_random_apply(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
transforms, prob = (list(args) + 2 * [None])[:2]
|
||||
if "transforms" in kwargs:
|
||||
transforms = kwargs.get("transforms")
|
||||
if transforms is None:
|
||||
raise ValueError("transforms is not provided.")
|
||||
check_list(transforms)
|
||||
kwargs["transforms"] = transforms
|
||||
[transforms, prob], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(transforms, (list,), "transforms")
|
||||
|
||||
if "prob" in kwargs:
|
||||
prob = kwargs.get("prob")
|
||||
if prob is not None:
|
||||
check_value(prob, [0., 1.])
|
||||
kwargs["prob"] = prob
|
||||
type_check(prob, (float, int,), "prob")
|
||||
check_value(prob, [0., 1.], "prob")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -486,23 +311,13 @@ def check_ten_crop(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
size, use_vertical_flip = args
|
||||
if "size" in kwargs:
|
||||
size = kwargs.get("size")
|
||||
if "use_vertical_flip" in kwargs:
|
||||
use_vertical_flip = kwargs.get("use_vertical_flip")
|
||||
|
||||
if size is None:
|
||||
raise ValueError("size is not provided.")
|
||||
size = check_crop_size(size)
|
||||
kwargs["size"] = size
|
||||
[size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_crop_size(size)
|
||||
|
||||
if use_vertical_flip is not None:
|
||||
check_bool(use_vertical_flip)
|
||||
kwargs["use_vertical_flip"] = use_vertical_flip
|
||||
type_check(use_vertical_flip, (bool,), "use_vertical_flip")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -512,16 +327,13 @@ def check_num_channels(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
num_output_channels = (list(args) + [None])[0]
|
||||
if "num_output_channels" in kwargs:
|
||||
num_output_channels = kwargs.get("num_output_channels")
|
||||
[num_output_channels], _ = parse_user_args(method, *args, **kwargs)
|
||||
if num_output_channels is not None:
|
||||
if num_output_channels not in (1, 3):
|
||||
raise ValueError("Number of channels of the output grayscale image"
|
||||
"should be either 1 or 3. Got {0}".format(num_output_channels))
|
||||
kwargs["num_output_channels"] = num_output_channels
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -531,28 +343,12 @@ def check_pad(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 3 * [None])[:3]
|
||||
padding, fill_value, padding_mode = args
|
||||
if "padding" in kwargs:
|
||||
padding = kwargs.get("padding")
|
||||
if "fill_value" in kwargs:
|
||||
fill_value = kwargs.get("fill_value")
|
||||
if "padding_mode" in kwargs:
|
||||
padding_mode = kwargs.get("padding_mode")
|
||||
[padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_padding(padding)
|
||||
check_fill_value(fill_value)
|
||||
type_check(padding_mode, (Border,), "padding_mode")
|
||||
|
||||
if padding is None:
|
||||
raise ValueError("padding is not provided.")
|
||||
padding = check_padding(padding)
|
||||
kwargs["padding"] = padding
|
||||
|
||||
if fill_value is not None:
|
||||
fill_value = check_fill_value(fill_value)
|
||||
kwargs["fill_value"] = fill_value
|
||||
if padding_mode is not None:
|
||||
check_border_type(padding_mode)
|
||||
kwargs["padding_mode"] = padding_mode
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -562,26 +358,13 @@ def check_random_perspective(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 3 * [None])[:3]
|
||||
distortion_scale, prob, interpolation = args
|
||||
if "distortion_scale" in kwargs:
|
||||
distortion_scale = kwargs.get("distortion_scale")
|
||||
if "prob" in kwargs:
|
||||
prob = kwargs.get("prob")
|
||||
if "interpolation" in kwargs:
|
||||
interpolation = kwargs.get("interpolation")
|
||||
[distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if distortion_scale is not None:
|
||||
check_value(distortion_scale, [0., 1.])
|
||||
kwargs["distortion_scale"] = distortion_scale
|
||||
if prob is not None:
|
||||
check_value(prob, [0., 1.])
|
||||
kwargs["prob"] = prob
|
||||
if interpolation is not None:
|
||||
check_inter_mode(interpolation)
|
||||
kwargs["interpolation"] = interpolation
|
||||
check_value(distortion_scale, [0., 1.], "distortion_scale")
|
||||
check_value(prob, [0., 1.], "prob")
|
||||
type_check(interpolation, (Inter,), "interpolation")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -591,28 +374,13 @@ def check_mix_up(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 3 * [None])[:3]
|
||||
batch_size, alpha, is_single = args
|
||||
if "batch_size" in kwargs:
|
||||
batch_size = kwargs.get("batch_size")
|
||||
if "alpha" in kwargs:
|
||||
alpha = kwargs.get("alpha")
|
||||
if "is_single" in kwargs:
|
||||
is_single = kwargs.get("is_single")
|
||||
[batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if batch_size is None:
|
||||
raise ValueError("batch_size")
|
||||
check_pos_int32(batch_size)
|
||||
kwargs["batch_size"] = batch_size
|
||||
if alpha is None:
|
||||
raise ValueError("alpha")
|
||||
check_positive(alpha)
|
||||
kwargs["alpha"] = alpha
|
||||
if is_single is not None:
|
||||
check_type(is_single, bool)
|
||||
kwargs["is_single"] = is_single
|
||||
check_value(batch_size, (1, FLOAT_MAX_INTEGER))
|
||||
check_positive(alpha, "alpha")
|
||||
type_check(is_single, (bool,), "is_single")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -622,41 +390,16 @@ def check_random_erasing(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 6 * [None])[:6]
|
||||
prob, scale, ratio, value, inplace, max_attempts = args
|
||||
if "prob" in kwargs:
|
||||
prob = kwargs.get("prob")
|
||||
if "scale" in kwargs:
|
||||
scale = kwargs.get("scale")
|
||||
if "ratio" in kwargs:
|
||||
ratio = kwargs.get("ratio")
|
||||
if "value" in kwargs:
|
||||
value = kwargs.get("value")
|
||||
if "inplace" in kwargs:
|
||||
inplace = kwargs.get("inplace")
|
||||
if "max_attempts" in kwargs:
|
||||
max_attempts = kwargs.get("max_attempts")
|
||||
[prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if prob is not None:
|
||||
check_value(prob, [0., 1.])
|
||||
kwargs["prob"] = prob
|
||||
if scale is not None:
|
||||
check_value(prob, [0., 1.], "prob")
|
||||
check_range(scale, [0, FLOAT_MAX_INTEGER])
|
||||
kwargs["scale"] = scale
|
||||
if ratio is not None:
|
||||
check_range(ratio, [0, FLOAT_MAX_INTEGER])
|
||||
kwargs["ratio"] = ratio
|
||||
if value is not None:
|
||||
check_erasing_value(value)
|
||||
kwargs["value"] = value
|
||||
if inplace is not None:
|
||||
check_bool(inplace)
|
||||
kwargs["inplace"] = inplace
|
||||
if max_attempts is not None:
|
||||
check_pos_int32(max_attempts)
|
||||
kwargs["max_attempts"] = max_attempts
|
||||
type_check(inplace, (bool,), "inplace")
|
||||
check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -666,23 +409,12 @@ def check_cutout(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
length, num_patches = args
|
||||
if "length" in kwargs:
|
||||
length = kwargs.get("length")
|
||||
if "num_patches" in kwargs:
|
||||
num_patches = kwargs.get("num_patches")
|
||||
[length, num_patches], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if length is None:
|
||||
raise ValueError("length")
|
||||
check_pos_int32(length)
|
||||
kwargs["length"] = length
|
||||
check_value(length, (1, FLOAT_MAX_INTEGER))
|
||||
check_value(num_patches, (1, FLOAT_MAX_INTEGER))
|
||||
|
||||
if num_patches is not None:
|
||||
check_pos_int32(num_patches)
|
||||
kwargs["num_patches"] = num_patches
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -692,17 +424,9 @@ def check_linear_transform(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 2 * [None])[:2]
|
||||
transformation_matrix, mean_vector = args
|
||||
if "transformation_matrix" in kwargs:
|
||||
transformation_matrix = kwargs.get("transformation_matrix")
|
||||
if "mean_vector" in kwargs:
|
||||
mean_vector = kwargs.get("mean_vector")
|
||||
|
||||
if transformation_matrix is None:
|
||||
raise ValueError("transformation_matrix is not provided.")
|
||||
if mean_vector is None:
|
||||
raise ValueError("mean_vector is not provided.")
|
||||
[transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(transformation_matrix, (np.ndarray,), "transformation_matrix")
|
||||
type_check(mean_vector, (np.ndarray,), "mean_vector")
|
||||
|
||||
if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
|
||||
raise ValueError("transformation_matrix should be a square matrix. "
|
||||
|
@ -711,10 +435,7 @@ def check_linear_transform(method):
|
|||
raise ValueError("mean_vector length {0} should match either one dimension of the square"
|
||||
"transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
|
||||
|
||||
kwargs["transformation_matrix"] = transformation_matrix
|
||||
kwargs["mean_vector"] = mean_vector
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -724,67 +445,40 @@ def check_random_affine(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
args = (list(args) + 6 * [None])[:6]
|
||||
degrees, translate, scale, shear, resample, fill_value = args
|
||||
if "degrees" in kwargs:
|
||||
degrees = kwargs.get("degrees")
|
||||
if "translate" in kwargs:
|
||||
translate = kwargs.get("translate")
|
||||
if "scale" in kwargs:
|
||||
scale = kwargs.get("scale")
|
||||
if "shear" in kwargs:
|
||||
shear = kwargs.get("shear")
|
||||
if "resample" in kwargs:
|
||||
resample = kwargs.get("resample")
|
||||
if "fill_value" in kwargs:
|
||||
fill_value = kwargs.get("fill_value")
|
||||
|
||||
if degrees is None:
|
||||
raise ValueError("degrees is not provided.")
|
||||
degrees = check_degrees(degrees)
|
||||
kwargs["degrees"] = degrees
|
||||
[degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_degrees(degrees)
|
||||
|
||||
if translate is not None:
|
||||
if isinstance(translate, (tuple, list)) and len(translate) == 2:
|
||||
for t in translate:
|
||||
if t < 0.0 or t > 1.0:
|
||||
raise ValueError("translation values should be between 0 and 1")
|
||||
else:
|
||||
if type_check(translate, (list, tuple), "translate"):
|
||||
translate_names = ["translate_{0}".format(i) for i in range(len(translate))]
|
||||
type_check_list(translate, (int, float), translate_names)
|
||||
if len(translate) != 2:
|
||||
raise TypeError("translate should be a list or tuple of length 2.")
|
||||
kwargs["translate"] = translate
|
||||
for i, t in enumerate(translate):
|
||||
check_value(t, [0.0, 1.0], "translate at {0}".format(i))
|
||||
|
||||
if scale is not None:
|
||||
if isinstance(scale, (tuple, list)) and len(scale) == 2:
|
||||
for s in scale:
|
||||
if s <= 0:
|
||||
raise ValueError("scale values should be positive")
|
||||
type_check(scale, (tuple, list), "scale")
|
||||
if len(scale) == 2:
|
||||
for i, s in enumerate(scale):
|
||||
check_positive(s, "scale[{}]".format(i))
|
||||
else:
|
||||
raise TypeError("scale should be a list or tuple of length 2.")
|
||||
kwargs["scale"] = scale
|
||||
|
||||
if shear is not None:
|
||||
type_check(shear, (numbers.Number, tuple, list), "shear")
|
||||
if isinstance(shear, numbers.Number):
|
||||
if shear < 0:
|
||||
raise ValueError("If shear is a single number, it must be positive.")
|
||||
shear = (-1 * shear, shear)
|
||||
elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4):
|
||||
# X-Axis shear with [min, max]
|
||||
if len(shear) == 2:
|
||||
shear = [shear[0], shear[1], 0., 0.]
|
||||
elif len(shear) == 4:
|
||||
shear = [s for s in shear]
|
||||
check_positive(shear, "shear")
|
||||
else:
|
||||
raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.")
|
||||
kwargs["shear"] = shear
|
||||
if len(shear) not in (2, 4):
|
||||
raise TypeError("shear must be of length 2 or 4.")
|
||||
|
||||
type_check(resample, (Inter,), "resample")
|
||||
|
||||
if resample is not None:
|
||||
check_inter_mode(resample)
|
||||
kwargs["resample"] = resample
|
||||
if fill_value is not None:
|
||||
fill_value = check_fill_value(fill_value)
|
||||
kwargs["fill_value"] = fill_value
|
||||
check_fill_value(fill_value)
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -794,24 +488,11 @@ def check_rescale(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
rescale, shift = (list(args) + 2 * [None])[:2]
|
||||
if "rescale" in kwargs:
|
||||
rescale = kwargs.get("rescale")
|
||||
if "shift" in kwargs:
|
||||
shift = kwargs.get("shift")
|
||||
|
||||
if rescale is None:
|
||||
raise ValueError("rescale is not provided.")
|
||||
[rescale, shift], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_pos_float32(rescale)
|
||||
kwargs["rescale"] = rescale
|
||||
type_check(shift, (numbers.Number,), "shift")
|
||||
|
||||
if shift is None:
|
||||
raise ValueError("shift is not provided.")
|
||||
if not isinstance(shift, numbers.Number):
|
||||
raise TypeError("shift is not a number.")
|
||||
kwargs["shift"] = shift
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -821,33 +502,16 @@ def check_uniform_augment_cpp(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
operations, num_ops = (list(args) + 2 * [None])[:2]
|
||||
if "operations" in kwargs:
|
||||
operations = kwargs.get("operations")
|
||||
else:
|
||||
raise ValueError("operations list required")
|
||||
if "num_ops" in kwargs:
|
||||
num_ops = kwargs.get("num_ops")
|
||||
else:
|
||||
num_ops = 2
|
||||
[operations, num_ops], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(num_ops, (int,), "num_ops")
|
||||
check_positive(num_ops, "num_ops")
|
||||
|
||||
if not isinstance(num_ops, int):
|
||||
raise ValueError("Number of operations should be an integer.")
|
||||
|
||||
if num_ops <= 0:
|
||||
raise ValueError("num_ops should be greater than zero")
|
||||
if num_ops > len(operations):
|
||||
raise ValueError("num_ops is greater than operations list size")
|
||||
if not isinstance(operations, list):
|
||||
raise TypeError("operations is not a python list")
|
||||
for op in operations:
|
||||
if not isinstance(op, TensorOp):
|
||||
raise ValueError("operations list only accepts C++ operations.")
|
||||
tensor_ops = ["tensor_op_{0}".format(i) for i in range(len(operations))]
|
||||
type_check_list(operations, (TensorOp,), tensor_ops)
|
||||
|
||||
kwargs["num_ops"] = num_ops
|
||||
kwargs["operations"] = operations
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -857,23 +521,11 @@ def check_bounding_box_augment_cpp(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
transform, ratio = (list(args) + 2 * [None])[:2]
|
||||
if "transform" in kwargs:
|
||||
transform = kwargs.get("transform")
|
||||
if "ratio" in kwargs:
|
||||
ratio = kwargs.get("ratio")
|
||||
if not isinstance(ratio, float) and not isinstance(ratio, int):
|
||||
raise ValueError("Ratio should be an int or float.")
|
||||
if ratio is not None:
|
||||
check_value(ratio, [0., 1.])
|
||||
kwargs["ratio"] = ratio
|
||||
else:
|
||||
ratio = 0.3
|
||||
if not isinstance(transform, TensorOp):
|
||||
raise ValueError("Transform can only be a C++ operation.")
|
||||
kwargs["transform"] = transform
|
||||
kwargs["ratio"] = ratio
|
||||
return method(self, **kwargs)
|
||||
[transform, ratio], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(ratio, (float, int), "ratio")
|
||||
check_value(ratio, [0., 1.], "ratio")
|
||||
type_check(transform, (TensorOp,), "transform")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -883,29 +535,22 @@ def check_uniform_augment_py(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
transforms, num_ops = (list(args) + 2 * [None])[:2]
|
||||
if "transforms" in kwargs:
|
||||
transforms = kwargs.get("transforms")
|
||||
if transforms is None:
|
||||
raise ValueError("transforms is not provided.")
|
||||
[transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(transforms, (list,), "transforms")
|
||||
|
||||
if not transforms:
|
||||
raise ValueError("transforms list is empty.")
|
||||
check_list(transforms)
|
||||
|
||||
for transform in transforms:
|
||||
if isinstance(transform, TensorOp):
|
||||
raise ValueError("transform list only accepts Python operations.")
|
||||
kwargs["transforms"] = transforms
|
||||
|
||||
if "num_ops" in kwargs:
|
||||
num_ops = kwargs.get("num_ops")
|
||||
if num_ops is not None:
|
||||
check_type(num_ops, int)
|
||||
check_positive(num_ops)
|
||||
type_check(num_ops, (int,), "num_ops")
|
||||
check_positive(num_ops, "num_ops")
|
||||
if num_ops > len(transforms):
|
||||
raise ValueError("num_ops cannot be greater than the length of transforms list.")
|
||||
kwargs["num_ops"] = num_ops
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -915,22 +560,16 @@ def check_positive_degrees(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
degrees = (list(args) + [None])[0]
|
||||
if "degrees" in kwargs:
|
||||
degrees = kwargs.get("degrees")
|
||||
[degrees], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if degrees is not None:
|
||||
if isinstance(degrees, (list, tuple)):
|
||||
if len(degrees) != 2:
|
||||
raise ValueError("Degrees must be a sequence with length 2.")
|
||||
if degrees[0] < 0:
|
||||
raise ValueError("Degrees range must be non-negative.")
|
||||
check_positive(degrees[0], "degrees[0]")
|
||||
if degrees[0] > degrees[1]:
|
||||
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
|
||||
else:
|
||||
raise TypeError("Degrees must be a sequence in (min,max) format.")
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
@ -940,18 +579,12 @@ def check_compose_list(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
transforms = (list(args) + [None])[0]
|
||||
if "transforms" in kwargs:
|
||||
transforms = kwargs.get("transforms")
|
||||
if transforms is None:
|
||||
raise ValueError("transforms is not provided.")
|
||||
[transforms], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(transforms, (list,), transforms)
|
||||
if not transforms:
|
||||
raise ValueError("transforms list is empty.")
|
||||
if not isinstance(transforms, list):
|
||||
raise TypeError("transforms is not a python list")
|
||||
|
||||
kwargs["transforms"] = transforms
|
||||
|
||||
return method(self, **kwargs)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -15,13 +15,15 @@
|
|||
"""
|
||||
Testing the bounding box augment op in DE
|
||||
"""
|
||||
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
|
||||
config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
|
||||
|
||||
import numpy as np
|
||||
import mindspore.log as logger
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
|
||||
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
|
||||
config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
# updated VOC dataset with correct annotations
|
||||
|
@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c():
|
|||
operations=[test_op]) # Add column for "annotation"
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||
assert "Input is not" in str(error)
|
||||
assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)
|
||||
|
||||
|
||||
def test_bounding_box_augment_invalid_bounds_c():
|
||||
|
|
|
@ -17,6 +17,7 @@ import pytest
|
|||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
# generates 1 column [0], [0, 1], ..., [0, ..., n-1]
|
||||
def generate_sequential(n):
|
||||
for i in range(n):
|
||||
|
@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input():
|
|||
with pytest.raises(TypeError) as info:
|
||||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
None, None, invalid_type_pad_to_bucket_boundary)
|
||||
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
|
||||
assert "Argument pad_to_bucket_boundary with value \"\" is not of type (<class \'bool\'>,)." in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
None, None, False, invalid_type_drop_remainder)
|
||||
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
|
||||
assert "Argument drop_remainder with value \"\" is not of type (<class 'bool'>,)." in str(info.value)
|
||||
|
||||
|
||||
def test_bucket_batch_multi_bucket_no_padding():
|
||||
|
@ -272,7 +273,6 @@ def test_bucket_batch_default_pad():
|
|||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]]
|
||||
|
||||
|
||||
output = []
|
||||
for data in dataset.create_dict_iterator():
|
||||
output.append(data["col1"].tolist())
|
||||
|
|
|
@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis():
|
|||
|
||||
|
||||
def test_concatenate_op_incorrect_input_dim():
|
||||
def gen():
|
||||
yield (np.array(["ss", "ad"], dtype='S'),)
|
||||
|
||||
prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
concatenate_op = data_trans.Concatenate(0, prepend_tensor)
|
||||
|
||||
data = data.map(input_columns=["col"], operations=concatenate_op)
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
for _ in data:
|
||||
pass
|
||||
assert "Only 1D tensors supported" in repr(error_info.value)
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
data_trans.Concatenate(0, prepend_tensor)
|
||||
assert "can only prepend 1D arrays." in repr(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -28,9 +28,9 @@ def test_exception_01():
|
|||
"""
|
||||
logger.info("test_exception_01")
|
||||
data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))
|
||||
assert "Invalid interpolation mode." in str(info.value)
|
||||
with pytest.raises(TypeError) as info:
|
||||
data.map(input_columns=["image"], operations=vision.Resize(100, 100))
|
||||
assert "Argument interpolation with value 100 is not of type (<enum 'Inter'>,)" in str(info.value)
|
||||
|
||||
|
||||
def test_exception_02():
|
||||
|
@ -40,8 +40,8 @@ def test_exception_02():
|
|||
logger.info("test_exception_02")
|
||||
num_samples = -1
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
assert "num_samples cannot be less than 0" in str(info.value)
|
||||
ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value)
|
||||
|
||||
num_samples = 1
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
|
|
|
@ -23,7 +23,8 @@ import mindspore.dataset.text as text
|
|||
def test_demo_basic_from_dataset():
|
||||
""" this is a tutorial on how from_dataset should be used in a normal use case"""
|
||||
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
||||
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
|
||||
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None,
|
||||
special_tokens=["<pad>", "<unk>"],
|
||||
special_first=True)
|
||||
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
||||
res = []
|
||||
|
@ -127,15 +128,16 @@ def test_from_dataset_exceptions():
|
|||
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
||||
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
|
||||
assert isinstance(vocab.text.Vocab)
|
||||
except ValueError as e:
|
||||
except (TypeError, ValueError, RuntimeError) as e:
|
||||
assert s in str(e), str(e)
|
||||
|
||||
test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
|
||||
test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
|
||||
test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
|
||||
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
|
||||
test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
|
||||
test_config([123], (2, 3), 0, "columns need to be a list of strings")
|
||||
test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
|
||||
test_config("text", (2, 3), 1.2345,
|
||||
"Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
|
||||
test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
|
||||
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
|
||||
test_config("text", (2, 3), 0, "top_k needs to be positive number")
|
||||
test_config([123], (2, 3), 0, "top_k needs to be positive number")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False):
|
|||
if plot:
|
||||
visualize_list(image, image_transformed)
|
||||
|
||||
|
||||
def test_linear_transformation_md5():
|
||||
"""
|
||||
Test LinearTransformation op: valid params (transformation_matrix, mean_vector)
|
||||
|
@ -102,6 +103,7 @@ def test_linear_transformation_md5():
|
|||
filename = "linear_transformation_01_result.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_linear_transformation_exception_01():
|
||||
"""
|
||||
Test LinearTransformation op: transformation_matrix is not provided
|
||||
|
@ -126,9 +128,10 @@ def test_linear_transformation_exception_01():
|
|||
]
|
||||
transform = py_vision.ComposeOp(transforms)
|
||||
data1 = data1.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "not provided" in str(e)
|
||||
assert "Argument transformation_matrix with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
|
||||
|
||||
|
||||
def test_linear_transformation_exception_02():
|
||||
"""
|
||||
|
@ -154,9 +157,10 @@ def test_linear_transformation_exception_02():
|
|||
]
|
||||
transform = py_vision.ComposeOp(transforms)
|
||||
data1 = data1.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "not provided" in str(e)
|
||||
assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
|
||||
|
||||
|
||||
def test_linear_transformation_exception_03():
|
||||
"""
|
||||
|
@ -187,6 +191,7 @@ def test_linear_transformation_exception_03():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "square matrix" in str(e)
|
||||
|
||||
|
||||
def test_linear_transformation_exception_04():
|
||||
"""
|
||||
Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
|
||||
|
@ -216,6 +221,7 @@ def test_linear_transformation_exception_04():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "should match" in str(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear_transformation_op(plot=True)
|
||||
test_linear_transformation_md5()
|
||||
|
|
|
@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards():
|
|||
create_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
|
||||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
|
||||
def test_minddataset_invalidate_shard_id():
|
||||
create_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
|
@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard():
|
|||
create_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
|
||||
|
||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
|
||||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
"""
|
||||
Testing Ngram in mindspore.dataset
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_multiple_ngrams():
|
||||
|
@ -61,7 +61,7 @@ def test_simple_ngram():
|
|||
yield (np.array(line.split(" "), dtype='S'),)
|
||||
|
||||
dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
|
||||
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None))
|
||||
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" "))
|
||||
|
||||
i = 0
|
||||
for data in dataset.create_dict_iterator():
|
||||
|
@ -72,7 +72,7 @@ def test_simple_ngram():
|
|||
def test_corner_cases():
|
||||
""" testing various corner cases and exceptions"""
|
||||
|
||||
def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None):
|
||||
def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
|
||||
def gen(texts):
|
||||
yield (np.array(texts.split(" "), dtype='S'),)
|
||||
|
||||
|
@ -93,7 +93,7 @@ def test_corner_cases():
|
|||
try:
|
||||
test_config("Yours to Discover", "", [0, [1]])
|
||||
except Exception as e:
|
||||
assert "ngram needs to be a positive number" in str(e)
|
||||
assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
|
||||
# test empty n
|
||||
try:
|
||||
test_config("Yours to Discover", "", [])
|
||||
|
|
|
@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py():
|
|||
_ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input is not within the required range" in str(e)
|
||||
assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||
|
||||
|
||||
def test_normalize_grayscale_md5_01():
|
||||
|
|
|
@ -61,6 +61,10 @@ def test_pad_end_exceptions():
|
|||
pad_compare([3, 4, 5], ["2"], 1, [])
|
||||
assert "a value in the list is not an integer." in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
pad_compare([1, 2], 3, -1, [1, 2, -1])
|
||||
assert "Argument pad_end with value 3 is not of type (<class 'list'>,)" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pad_end_basics()
|
||||
|
|
|
@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees():
|
|||
_ = py_vision.RandomAffine(degrees=-15)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "If degrees is a single number, it cannot be negative."
|
||||
assert str(e) == "Input degrees is not within the required interval of (0 to inf)."
|
||||
|
||||
|
||||
def test_random_affine_exception_translation_range():
|
||||
|
@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range():
|
|||
_ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "translation values should be between 0 and 1"
|
||||
assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)."
|
||||
|
||||
|
||||
def test_random_affine_exception_scale_value():
|
||||
|
@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value():
|
|||
_ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "scale values should be positive"
|
||||
assert str(e) == "Input scale[0] must be greater than 0."
|
||||
|
||||
|
||||
def test_random_affine_exception_shear_value():
|
||||
|
@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value():
|
|||
_ = py_vision.RandomAffine(degrees=15, shear=-5)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "If shear is a single number, it must be positive."
|
||||
assert str(e) == "Input shear must be greater than 0."
|
||||
|
||||
|
||||
def test_random_affine_exception_degrees_size():
|
||||
|
@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size():
|
|||
_ = py_vision.RandomAffine(degrees=15, translate=(0.1))
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "translate should be a list or tuple of length 2."
|
||||
assert str(
|
||||
e) == "Argument translate with value 0.1 is not of type (<class 'list'>," \
|
||||
" <class 'tuple'>)."
|
||||
|
||||
|
||||
def test_random_affine_exception_scale_size():
|
||||
|
@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size():
|
|||
_ = py_vision.RandomAffine(degrees=15, scale=(0.5))
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "scale should be a list or tuple of length 2."
|
||||
assert str(e) == "Argument scale with value 0.5 is not of type (<class 'tuple'>," \
|
||||
" <class 'list'>)."
|
||||
|
||||
|
||||
def test_random_affine_exception_shear_size():
|
||||
|
@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size():
|
|||
_ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10))
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4."
|
||||
assert str(e) == "shear must be of length 2 or 4."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -97,7 +97,7 @@ def test_random_color_md5():
|
|||
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms = F.ComposeOp([F.Decode(),
|
||||
F.RandomColor((0.5, 1.5)),
|
||||
F.RandomColor((0.1, 1.9)),
|
||||
F.ToTensor()])
|
||||
|
||||
data = data.map(input_columns="image", operations=transforms())
|
||||
|
|
|
@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c():
|
|||
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_04_py():
|
||||
|
@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py():
|
|||
data = data.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_05_c():
|
||||
|
@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c():
|
|||
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_05_py():
|
||||
|
@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py():
|
|||
data = data.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_comp(plot=False):
|
||||
|
|
|
@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
|
|||
|
||||
except ValueError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "Input range is not valid" in str(err)
|
||||
assert "Input is not within the required interval of (0 to 16777216)." in str(err)
|
||||
|
||||
|
||||
def test_random_resized_crop_with_bbox_op_invalid2_c():
|
||||
|
@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
|
|||
|
||||
except ValueError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "Input range is not valid" in str(err)
|
||||
assert "Input is not within the required interval of (0 to 16777216)." in str(err)
|
||||
|
||||
|
||||
def test_random_resized_crop_with_bbox_op_bad_c():
|
||||
|
|
|
@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param():
|
|||
data = data.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input is not within the required range" in str(e)
|
||||
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_grayscale_valid_prob(True)
|
||||
|
|
|
@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c():
|
|||
data = data.map(input_columns=["image"], operations=random_horizontal_op)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input is not" in str(e)
|
||||
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||
|
||||
|
||||
def test_random_horizontal_invalid_prob_py():
|
||||
|
@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py():
|
|||
data = data.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input is not" in str(e)
|
||||
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||
|
||||
|
||||
def test_random_horizontal_comp(plot=False):
|
||||
|
|
|
@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
|
|||
operations=[test_op]) # Add column for "annotation"
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||
assert "Input is not" in str(error)
|
||||
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error)
|
||||
|
||||
|
||||
def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
|
||||
|
|
|
@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range():
|
|||
_ = py_vision.RandomPerspective(distortion_scale=1.5)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required range"
|
||||
assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)."
|
||||
|
||||
|
||||
def test_random_perspective_exception_prob_range():
|
||||
|
@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range():
|
|||
_ = py_vision.RandomPerspective(prob=1.2)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required range"
|
||||
assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c():
|
|||
|
||||
except ValueError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "Input is not" in str(err)
|
||||
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
|
||||
|
||||
try:
|
||||
# one of the size values is zero
|
||||
|
@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c():
|
|||
|
||||
except ValueError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "Input is not" in str(err)
|
||||
assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err)
|
||||
|
||||
try:
|
||||
# negative value for resize
|
||||
|
@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c():
|
|||
|
||||
except ValueError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "Input is not" in str(err)
|
||||
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
|
||||
|
||||
try:
|
||||
# invalid input shape
|
||||
|
|
|
@ -97,7 +97,7 @@ def test_random_sharpness_md5():
|
|||
# define map operations
|
||||
transforms = [
|
||||
F.Decode(),
|
||||
F.RandomSharpness((0.5, 1.5)),
|
||||
F.RandomSharpness((0.1, 1.9)),
|
||||
F.ToTensor()
|
||||
]
|
||||
transform = F.ComposeOp(transforms)
|
||||
|
|
|
@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c():
|
|||
data = data.map(input_columns=["image"], operations=random_horizontal_op)
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input is not" in str(e)
|
||||
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
|
||||
|
||||
|
||||
def test_random_vertical_invalid_prob_py():
|
||||
|
@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py():
|
|||
data = data.map(input_columns=["image"], operations=transform())
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input is not" in str(e)
|
||||
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
|
||||
|
||||
|
||||
def test_random_vertical_comp(plot=False):
|
||||
|
|
|
@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
|
|||
|
||||
except ValueError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "Input is not" in str(err)
|
||||
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err)
|
||||
|
||||
|
||||
def test_random_vertical_flip_with_bbox_op_bad_c():
|
||||
|
|
|
@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c():
|
|||
# invalid interpolation value
|
||||
c_vision.ResizeWithBBox(400, interpolation="invalid")
|
||||
|
||||
except ValueError as err:
|
||||
except TypeError as err:
|
||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||
assert "interpolation" in str(err)
|
||||
|
||||
|
|
|
@ -154,7 +154,7 @@ def test_shuffle_exception_01():
|
|||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "buffer_size" in str(e)
|
||||
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
|
||||
|
||||
|
||||
def test_shuffle_exception_02():
|
||||
|
@ -172,7 +172,7 @@ def test_shuffle_exception_02():
|
|||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "buffer_size" in str(e)
|
||||
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
|
||||
|
||||
|
||||
def test_shuffle_exception_03():
|
||||
|
@ -190,7 +190,7 @@ def test_shuffle_exception_03():
|
|||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "buffer_size" in str(e)
|
||||
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
|
||||
|
||||
|
||||
def test_shuffle_exception_05():
|
||||
|
|
|
@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg():
|
|||
vision.TenCrop(0),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
error_msg = "Input is not within the required range"
|
||||
error_msg = "Input is not within the required interval of (1 to 16777216)."
|
||||
assert error_msg == str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
|
|
|
@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
|||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "operations" in str(e)
|
||||
assert "Argument tensor_op_5 with value" \
|
||||
" <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e)
|
||||
assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
||||
|
@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
|||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "num_ops" in str(e)
|
||||
assert "Input num_ops must be greater than 0" in str(e)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||
|
@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
|||
|
||||
except Exception as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "integer" in str(e)
|
||||
assert "Argument num_ops with value 2.5 is not of type (<class 'int'>,)" in str(e)
|
||||
|
||||
|
||||
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||
|
|
|
@ -321,7 +321,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
|
|||
|
||||
if len(orig) > plot_rows:
|
||||
# Create batches of required size and add remainder to last batch
|
||||
orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
|
||||
orig = np.split(orig[:split_point], batch_size) + (
|
||||
[orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
|
||||
aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
|
||||
else:
|
||||
orig = [orig]
|
||||
|
@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
|
|||
|
||||
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
|
||||
cur_ix = base_ix + x
|
||||
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row
|
||||
# select plotting axes based on number of image rows on plot - else case when 1 row
|
||||
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1])
|
||||
|
||||
axA.imshow(dataA["image"])
|
||||
add_bounding_boxes(axA, dataA[annot_name])
|
||||
|
|
Loading…
Reference in New Issue