forked from mindspore-Ecosystem/mindspore
fix validators
fixed random_apply tests fix validators fixed random_apply tests fix engine validation
This commit is contained in:
parent
915ddd25dd
commit
6c37ea3be0
|
@ -0,0 +1,342 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
General Validators.
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from ..engine import samplers
|
||||||
|
|
||||||
|
# POS_INT_MIN is used to limit values from starting from 0
|
||||||
|
POS_INT_MIN = 1
|
||||||
|
UINT8_MAX = 255
|
||||||
|
UINT8_MIN = 0
|
||||||
|
UINT32_MAX = 4294967295
|
||||||
|
UINT32_MIN = 0
|
||||||
|
UINT64_MAX = 18446744073709551615
|
||||||
|
UINT64_MIN = 0
|
||||||
|
INT32_MAX = 2147483647
|
||||||
|
INT32_MIN = -2147483648
|
||||||
|
INT64_MAX = 9223372036854775807
|
||||||
|
INT64_MIN = -9223372036854775808
|
||||||
|
FLOAT_MAX_INTEGER = 16777216
|
||||||
|
FLOAT_MIN_INTEGER = -16777216
|
||||||
|
DOUBLE_MAX_INTEGER = 9007199254740992
|
||||||
|
DOUBLE_MIN_INTEGER = -9007199254740992
|
||||||
|
|
||||||
|
valid_detype = [
|
||||||
|
"bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
|
||||||
|
"uint32", "uint64", "float16", "float32", "float64", "string"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def pad_arg_name(arg_name):
|
||||||
|
if arg_name != "":
|
||||||
|
arg_name = arg_name + " "
|
||||||
|
return arg_name
|
||||||
|
|
||||||
|
|
||||||
|
def check_value(value, valid_range, arg_name=""):
|
||||||
|
arg_name = pad_arg_name(arg_name)
|
||||||
|
if value < valid_range[0] or value > valid_range[1]:
|
||||||
|
raise ValueError(
|
||||||
|
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
|
||||||
|
valid_range[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def check_range(values, valid_range, arg_name=""):
|
||||||
|
arg_name = pad_arg_name(arg_name)
|
||||||
|
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
|
||||||
|
raise ValueError(
|
||||||
|
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
|
||||||
|
valid_range[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def check_positive(value, arg_name=""):
|
||||||
|
arg_name = pad_arg_name(arg_name)
|
||||||
|
if value <= 0:
|
||||||
|
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
|
||||||
|
|
||||||
|
|
||||||
|
def check_positive_float(value, arg_name=""):
|
||||||
|
arg_name = pad_arg_name(arg_name)
|
||||||
|
type_check(value, (float,), arg_name)
|
||||||
|
check_positive(value, arg_name)
|
||||||
|
|
||||||
|
|
||||||
|
def check_2tuple(value, arg_name=""):
|
||||||
|
if not (isinstance(value, tuple) and len(value) == 2):
|
||||||
|
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
|
||||||
|
|
||||||
|
|
||||||
|
def check_uint8(value, arg_name=""):
|
||||||
|
type_check(value, (int,), arg_name)
|
||||||
|
check_value(value, [UINT8_MIN, UINT8_MAX])
|
||||||
|
|
||||||
|
|
||||||
|
def check_uint32(value, arg_name=""):
|
||||||
|
type_check(value, (int,), arg_name)
|
||||||
|
check_value(value, [UINT32_MIN, UINT32_MAX])
|
||||||
|
|
||||||
|
|
||||||
|
def check_pos_int32(value, arg_name=""):
|
||||||
|
type_check(value, (int,), arg_name)
|
||||||
|
check_value(value, [POS_INT_MIN, INT32_MAX])
|
||||||
|
|
||||||
|
|
||||||
|
def check_uint64(value, arg_name=""):
|
||||||
|
type_check(value, (int,), arg_name)
|
||||||
|
check_value(value, [UINT64_MIN, UINT64_MAX])
|
||||||
|
|
||||||
|
|
||||||
|
def check_pos_int64(value, arg_name=""):
|
||||||
|
type_check(value, (int,), arg_name)
|
||||||
|
check_value(value, [UINT64_MIN, INT64_MAX])
|
||||||
|
|
||||||
|
|
||||||
|
def check_pos_float32(value, arg_name=""):
|
||||||
|
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)
|
||||||
|
|
||||||
|
|
||||||
|
def check_pos_float64(value, arg_name=""):
|
||||||
|
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name)
|
||||||
|
|
||||||
|
|
||||||
|
def check_valid_detype(type_):
|
||||||
|
if type_ not in valid_detype:
|
||||||
|
raise ValueError("Unknown column type")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def check_columns(columns, name):
|
||||||
|
type_check(columns, (list, str), name)
|
||||||
|
if isinstance(columns, list):
|
||||||
|
if not columns:
|
||||||
|
raise ValueError("Column names should not be empty")
|
||||||
|
col_names = ["col_{0}".format(i) for i in range(len(columns))]
|
||||||
|
type_check_list(columns, (str,), col_names)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_user_args(method, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Parse user arguments in a function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method (method): a callable function
|
||||||
|
*args: user passed args
|
||||||
|
**kwargs: user passed kwargs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user_filled_args (list): values of what the user passed in for the arguments,
|
||||||
|
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
if 'self' in sig.parameters or 'cls' in sig.parameters:
|
||||||
|
ba = sig.bind(method, *args, **kwargs)
|
||||||
|
ba.apply_defaults()
|
||||||
|
params = list(sig.parameters.keys())[1:]
|
||||||
|
else:
|
||||||
|
ba = sig.bind(*args, **kwargs)
|
||||||
|
ba.apply_defaults()
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
user_filled_args = [ba.arguments.get(arg_value) for arg_value in params]
|
||||||
|
return user_filled_args, ba.arguments
|
||||||
|
|
||||||
|
|
||||||
|
def type_check_list(args, types, arg_names):
|
||||||
|
"""
|
||||||
|
Check the type of each parameter in the list
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (list, tuple): a list or tuple of any variable
|
||||||
|
types (tuple): tuple of all valid types for arg
|
||||||
|
arg_names (list, tuple of str): the names of args
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: when the type is not correct, otherwise nothing
|
||||||
|
"""
|
||||||
|
type_check(args, (list, tuple,), arg_names)
|
||||||
|
if len(args) != len(arg_names):
|
||||||
|
raise ValueError("List of arguments is not the same length as argument_names.")
|
||||||
|
for arg, arg_name in zip(args, arg_names):
|
||||||
|
type_check(arg, types, arg_name)
|
||||||
|
|
||||||
|
|
||||||
|
def type_check(arg, types, arg_name):
|
||||||
|
"""
|
||||||
|
Check the type of the parameter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg : any variable
|
||||||
|
types (tuple): tuple of all valid types for arg
|
||||||
|
arg_name (str): the name of arg
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: when the type is not correct, otherwise nothing
|
||||||
|
"""
|
||||||
|
# handle special case of booleans being a subclass of ints
|
||||||
|
print_value = '\"\"' if repr(arg) == repr('') else arg
|
||||||
|
|
||||||
|
if int in types and bool not in types:
|
||||||
|
if isinstance(arg, bool):
|
||||||
|
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
|
||||||
|
if not isinstance(arg, types):
|
||||||
|
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
|
||||||
|
|
||||||
|
|
||||||
|
def check_filename(path):
|
||||||
|
"""
|
||||||
|
check the filename in the path
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): the path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: when error
|
||||||
|
"""
|
||||||
|
if not isinstance(path, str):
|
||||||
|
raise TypeError("path: {} is not string".format(path))
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
|
||||||
|
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
|
||||||
|
# '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
|
||||||
|
# '*', '(', '%', ')', '-', '=', '{', '?', '$'
|
||||||
|
forbidden_symbols = set(r'\/:*?"<>|`&\';')
|
||||||
|
|
||||||
|
if set(filename) & forbidden_symbols:
|
||||||
|
raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
|
||||||
|
|
||||||
|
if filename.startswith(' ') or filename.endswith(' '):
|
||||||
|
raise ValueError("filename should not start/end with space")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def check_dir(dataset_dir):
|
||||||
|
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
|
||||||
|
raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def check_file(dataset_file):
|
||||||
|
check_filename(dataset_file)
|
||||||
|
if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
|
||||||
|
raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
|
||||||
|
|
||||||
|
|
||||||
|
def check_sampler_shuffle_shard_options(param_dict):
|
||||||
|
"""
|
||||||
|
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
|
||||||
|
Args:
|
||||||
|
param_dict (dict): param_dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: ValueError or RuntimeError if error
|
||||||
|
"""
|
||||||
|
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
||||||
|
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||||
|
|
||||||
|
type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
|
||||||
|
|
||||||
|
if sampler is not None:
|
||||||
|
if shuffle is not None:
|
||||||
|
raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
|
||||||
|
|
||||||
|
if num_shards is not None:
|
||||||
|
check_pos_int32(num_shards)
|
||||||
|
if shard_id is None:
|
||||||
|
raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
|
||||||
|
check_value(shard_id, [0, num_shards - 1], "shard_id")
|
||||||
|
|
||||||
|
if num_shards is None and shard_id is not None:
|
||||||
|
raise RuntimeError("shard_id is specified but num_shards is not.")
|
||||||
|
|
||||||
|
|
||||||
|
def check_padding_options(param_dict):
|
||||||
|
"""
|
||||||
|
Check for valid padded_sample and num_padded of padded samples
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_dict (dict): param_dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: ValueError or RuntimeError if error
|
||||||
|
"""
|
||||||
|
|
||||||
|
columns_list = param_dict.get('columns_list')
|
||||||
|
block_reader = param_dict.get('block_reader')
|
||||||
|
padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
|
||||||
|
if padded_sample is not None:
|
||||||
|
if num_padded is None:
|
||||||
|
raise RuntimeError("padded_sample is specified and requires num_padded as well.")
|
||||||
|
if num_padded < 0:
|
||||||
|
raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
|
||||||
|
if columns_list is None:
|
||||||
|
raise RuntimeError("padded_sample is specified and requires columns_list as well.")
|
||||||
|
for column in columns_list:
|
||||||
|
if column not in padded_sample:
|
||||||
|
raise ValueError("padded_sample cannot match columns_list.")
|
||||||
|
if block_reader:
|
||||||
|
raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.")
|
||||||
|
|
||||||
|
if padded_sample is None and num_padded is not None:
|
||||||
|
raise RuntimeError("num_padded is specified but padded_sample is not.")
|
||||||
|
|
||||||
|
|
||||||
|
def check_num_parallel_workers(value):
|
||||||
|
type_check(value, (int,), "num_parallel_workers")
|
||||||
|
if value < 1 or value > cpu_count():
|
||||||
|
raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
|
||||||
|
|
||||||
|
|
||||||
|
def check_num_samples(value):
|
||||||
|
type_check(value, (int,), "num_samples")
|
||||||
|
check_value(value, [0, INT32_MAX], "num_samples")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_dataset_param_value(param_list, param_dict, param_type):
|
||||||
|
for param_name in param_list:
|
||||||
|
if param_dict.get(param_name) is not None:
|
||||||
|
if param_name == 'num_parallel_workers':
|
||||||
|
check_num_parallel_workers(param_dict.get(param_name))
|
||||||
|
if param_name == 'num_samples':
|
||||||
|
check_num_samples(param_dict.get(param_name))
|
||||||
|
else:
|
||||||
|
type_check(param_dict.get(param_name), (param_type,), param_name)
|
||||||
|
|
||||||
|
|
||||||
|
def check_gnn_list_or_ndarray(param, param_name):
|
||||||
|
"""
|
||||||
|
Check if the input parameter is list or numpy.ndarray.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param (list, nd.ndarray): param
|
||||||
|
param_name (str): param_name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: TypeError if error
|
||||||
|
"""
|
||||||
|
|
||||||
|
type_check(param, (list, np.ndarray), param_name)
|
||||||
|
if isinstance(param, list):
|
||||||
|
param_names = ["param_{0}".format(i) for i in range(len(param))]
|
||||||
|
type_check_list(param, (int,), param_names)
|
||||||
|
|
||||||
|
elif isinstance(param, np.ndarray):
|
||||||
|
if not param.dtype == np.int32:
|
||||||
|
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||||
|
param_name, param.dtype))
|
File diff suppressed because it is too large
Load Diff
|
@ -98,7 +98,7 @@ class Ngram(cde.NgramOp):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_ngram
|
@check_ngram
|
||||||
def __init__(self, n, left_pad=None, right_pad=None, separator=None):
|
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
|
||||||
super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0],
|
super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0],
|
||||||
r_pad_token=right_pad[0], separator=separator)
|
r_pad_token=right_pad[0], separator=separator)
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ __all__ = [
|
||||||
"Vocab", "to_str", "to_bytes"
|
"Vocab", "to_str", "to_bytes"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Vocab(cde.Vocab):
|
class Vocab(cde.Vocab):
|
||||||
"""
|
"""
|
||||||
Vocab object that is used to lookup a word.
|
Vocab object that is used to lookup a word.
|
||||||
|
@ -38,7 +39,7 @@ class Vocab(cde.Vocab):
|
||||||
@classmethod
|
@classmethod
|
||||||
@check_from_dataset
|
@check_from_dataset
|
||||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
|
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
|
||||||
special_first=None):
|
special_first=True):
|
||||||
"""
|
"""
|
||||||
Build a vocab from a dataset.
|
Build a vocab from a dataset.
|
||||||
|
|
||||||
|
@ -62,13 +63,21 @@ class Vocab(cde.Vocab):
|
||||||
special_tokens(list, optional): a list of strings, each one is a special token. for example
|
special_tokens(list, optional): a list of strings, each one is a special token. for example
|
||||||
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
||||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
|
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
|
||||||
is specified and special_first is set to None, special_tokens will be prepended (default=None).
|
is specified and special_first is set to True, special_tokens will be prepended (default=True).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Vocab, Vocab object built from dataset.
|
Vocab, Vocab object built from dataset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
|
if columns is None:
|
||||||
|
columns = []
|
||||||
|
if not isinstance(columns, list):
|
||||||
|
columns = [columns]
|
||||||
|
if freq_range is None:
|
||||||
|
freq_range = (None, None)
|
||||||
|
if special_tokens is None:
|
||||||
|
special_tokens = []
|
||||||
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
|
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
|
||||||
for d in root.create_dict_iterator():
|
for d in root.create_dict_iterator():
|
||||||
if d is not None:
|
if d is not None:
|
||||||
|
@ -77,7 +86,7 @@ class Vocab(cde.Vocab):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@check_from_list
|
@check_from_list
|
||||||
def from_list(cls, word_list, special_tokens=None, special_first=None):
|
def from_list(cls, word_list, special_tokens=None, special_first=True):
|
||||||
"""
|
"""
|
||||||
Build a vocab object from a list of word.
|
Build a vocab object from a list of word.
|
||||||
|
|
||||||
|
@ -86,29 +95,33 @@ class Vocab(cde.Vocab):
|
||||||
special_tokens(list, optional): a list of strings, each one is a special token. for example
|
special_tokens(list, optional): a list of strings, each one is a special token. for example
|
||||||
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
||||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
|
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
|
||||||
is specified and special_first is set to None, special_tokens will be prepended (default=None).
|
is specified and special_first is set to True, special_tokens will be prepended (default=True).
|
||||||
"""
|
"""
|
||||||
|
if special_tokens is None:
|
||||||
|
special_tokens = []
|
||||||
return super().from_list(word_list, special_tokens, special_first)
|
return super().from_list(word_list, special_tokens, special_first)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@check_from_file
|
@check_from_file
|
||||||
def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
|
def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True):
|
||||||
"""
|
"""
|
||||||
Build a vocab object from a list of word.
|
Build a vocab object from a list of word.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path (str): path to the file which contains the vocab list.
|
file_path (str): path to the file which contains the vocab list.
|
||||||
delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be
|
delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be
|
||||||
the word (default=None).
|
the word (default="").
|
||||||
vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken).
|
vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken).
|
||||||
special_tokens (list, optional): a list of strings, each one is a special token. for example
|
special_tokens (list, optional): a list of strings, each one is a special token. for example
|
||||||
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
|
||||||
special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
|
special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
|
||||||
If special_tokens is specified and special_first is set to None,
|
If special_tokens is specified and special_first is set to True,
|
||||||
special_tokens will be prepended (default=None).
|
special_tokens will be prepended (default=True).
|
||||||
"""
|
"""
|
||||||
|
if vocab_size is None:
|
||||||
|
vocab_size = -1
|
||||||
|
if special_tokens is None:
|
||||||
|
special_tokens = []
|
||||||
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
|
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -17,23 +17,22 @@ validators for text ops
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import mindspore._c_dataengine as cde
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
import mindspore._c_dataengine as cde
|
||||||
from mindspore._c_expression import typing
|
from mindspore._c_expression import typing
|
||||||
from ..transforms.validators import check_uint32, check_pos_int64
|
|
||||||
|
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \
|
||||||
|
INT32_MAX, check_value
|
||||||
|
|
||||||
|
|
||||||
def check_unique_list_of_words(words, arg_name):
|
def check_unique_list_of_words(words, arg_name):
|
||||||
"""Check that words is a list and each element is a str without any duplication"""
|
"""Check that words is a list and each element is a str without any duplication"""
|
||||||
|
|
||||||
if not isinstance(words, list):
|
type_check(words, (list,), arg_name)
|
||||||
raise ValueError(arg_name + " needs to be a list of words of type string.")
|
|
||||||
words_set = set()
|
words_set = set()
|
||||||
for word in words:
|
for word in words:
|
||||||
if not isinstance(word, str):
|
type_check(word, (str,), arg_name)
|
||||||
raise ValueError("each word in " + arg_name + " needs to be type str.")
|
|
||||||
if word in words_set:
|
if word in words_set:
|
||||||
raise ValueError(arg_name + " contains duplicate word: " + word + ".")
|
raise ValueError(arg_name + " contains duplicate word: " + word + ".")
|
||||||
words_set.add(word)
|
words_set.add(word)
|
||||||
|
@ -45,21 +44,14 @@ def check_lookup(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
vocab, unknown = (list(args) + 2 * [None])[:2]
|
[vocab, unknown], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "vocab" in kwargs:
|
|
||||||
vocab = kwargs.get("vocab")
|
|
||||||
if "unknown" in kwargs:
|
|
||||||
unknown = kwargs.get("unknown")
|
|
||||||
if unknown is not None:
|
if unknown is not None:
|
||||||
if not (isinstance(unknown, int) and unknown >= 0):
|
type_check(unknown, (int,), "unknown")
|
||||||
raise ValueError("unknown needs to be a non-negative integer.")
|
check_positive(unknown)
|
||||||
|
type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
|
||||||
|
|
||||||
if not isinstance(vocab, cde.Vocab):
|
return method(self, *args, **kwargs)
|
||||||
raise ValueError("vocab is not an instance of cde.Vocab.")
|
|
||||||
|
|
||||||
kwargs["vocab"] = vocab
|
|
||||||
kwargs["unknown"] = unknown
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -69,50 +61,15 @@ def check_from_file(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
|
[file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
|
||||||
if "file_path" in kwargs:
|
**kwargs)
|
||||||
file_path = kwargs.get("file_path")
|
|
||||||
if "delimiter" in kwargs:
|
|
||||||
delimiter = kwargs.get("delimiter")
|
|
||||||
if "vocab_size" in kwargs:
|
|
||||||
vocab_size = kwargs.get("vocab_size")
|
|
||||||
if "special_tokens" in kwargs:
|
|
||||||
special_tokens = kwargs.get("special_tokens")
|
|
||||||
if "special_first" in kwargs:
|
|
||||||
special_first = kwargs.get("special_first")
|
|
||||||
|
|
||||||
if not isinstance(file_path, str):
|
|
||||||
raise ValueError("file_path needs to be str.")
|
|
||||||
|
|
||||||
if delimiter is not None:
|
|
||||||
if not isinstance(delimiter, str):
|
|
||||||
raise ValueError("delimiter needs to be str.")
|
|
||||||
else:
|
|
||||||
delimiter = ""
|
|
||||||
if vocab_size is not None:
|
|
||||||
if not (isinstance(vocab_size, int) and vocab_size > 0):
|
|
||||||
raise ValueError("vocab size needs to be a positive integer.")
|
|
||||||
else:
|
|
||||||
vocab_size = -1
|
|
||||||
|
|
||||||
if special_first is None:
|
|
||||||
special_first = True
|
|
||||||
|
|
||||||
if not isinstance(special_first, bool):
|
|
||||||
raise ValueError("special_first needs to be a boolean value")
|
|
||||||
|
|
||||||
if special_tokens is None:
|
|
||||||
special_tokens = []
|
|
||||||
|
|
||||||
check_unique_list_of_words(special_tokens, "special_tokens")
|
check_unique_list_of_words(special_tokens, "special_tokens")
|
||||||
|
type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
|
||||||
|
if vocab_size is not None:
|
||||||
|
check_value(vocab_size, (-1, INT32_MAX), "vocab_size")
|
||||||
|
type_check(special_first, (bool,), special_first)
|
||||||
|
|
||||||
kwargs["file_path"] = file_path
|
return method(self, *args, **kwargs)
|
||||||
kwargs["delimiter"] = delimiter
|
|
||||||
kwargs["vocab_size"] = vocab_size
|
|
||||||
kwargs["special_tokens"] = special_tokens
|
|
||||||
kwargs["special_first"] = special_first
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -122,33 +79,20 @@ def check_from_list(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
|
[word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "word_list" in kwargs:
|
|
||||||
word_list = kwargs.get("word_list")
|
|
||||||
if "special_tokens" in kwargs:
|
|
||||||
special_tokens = kwargs.get("special_tokens")
|
|
||||||
if "special_first" in kwargs:
|
|
||||||
special_first = kwargs.get("special_first")
|
|
||||||
if special_tokens is None:
|
|
||||||
special_tokens = []
|
|
||||||
word_set = check_unique_list_of_words(word_list, "word_list")
|
word_set = check_unique_list_of_words(word_list, "word_list")
|
||||||
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
|
if special_tokens is not None:
|
||||||
|
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
|
||||||
|
|
||||||
intersect = word_set.intersection(token_set)
|
intersect = word_set.intersection(token_set)
|
||||||
|
|
||||||
if intersect != set():
|
if intersect != set():
|
||||||
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
|
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
|
||||||
|
|
||||||
if special_first is None:
|
type_check(special_first, (bool,), "special_first")
|
||||||
special_first = True
|
|
||||||
|
|
||||||
if not isinstance(special_first, bool):
|
return method(self, *args, **kwargs)
|
||||||
raise ValueError("special_first needs to be a boolean value.")
|
|
||||||
|
|
||||||
kwargs["word_list"] = word_list
|
|
||||||
kwargs["special_tokens"] = special_tokens
|
|
||||||
kwargs["special_first"] = special_first
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -158,18 +102,15 @@ def check_from_dict(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
word_dict, = (list(args) + [None])[:1]
|
[word_dict], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "word_dict" in kwargs:
|
|
||||||
word_dict = kwargs.get("word_dict")
|
type_check(word_dict, (dict,), "word_dict")
|
||||||
if not isinstance(word_dict, dict):
|
|
||||||
raise ValueError("word_dict needs to be a list of word,id pairs.")
|
|
||||||
for word, word_id in word_dict.items():
|
for word, word_id in word_dict.items():
|
||||||
if not isinstance(word, str):
|
type_check(word, (str,), "word")
|
||||||
raise ValueError("Each word in word_dict needs to be type string.")
|
type_check(word_id, (int,), "word_id")
|
||||||
if not (isinstance(word_id, int) and word_id >= 0):
|
check_value(word_id, (-1, INT32_MAX), "word_id")
|
||||||
raise ValueError("Each word id needs to be positive integer.")
|
return method(self, *args, **kwargs)
|
||||||
kwargs["word_dict"] = word_dict
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -179,23 +120,8 @@ def check_jieba_init(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
|
parse_user_args(method, *args, **kwargs)
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
if "hmm_path" in kwargs:
|
|
||||||
hmm_path = kwargs.get("hmm_path")
|
|
||||||
if "mp_path" in kwargs:
|
|
||||||
mp_path = kwargs.get("mp_path")
|
|
||||||
if hmm_path is None:
|
|
||||||
raise ValueError(
|
|
||||||
"The dict of HMMSegment in cppjieba is not provided.")
|
|
||||||
kwargs["hmm_path"] = hmm_path
|
|
||||||
if mp_path is None:
|
|
||||||
raise ValueError(
|
|
||||||
"The dict of MPSegment in cppjieba is not provided.")
|
|
||||||
kwargs["mp_path"] = mp_path
|
|
||||||
if model is not None:
|
|
||||||
kwargs["model"] = model
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -205,19 +131,12 @@ def check_jieba_add_word(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
word, freq = (list(args) + 2 * [None])[:2]
|
[word, freq], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
if "word" in kwargs:
|
|
||||||
word = kwargs.get("word")
|
|
||||||
if "freq" in kwargs:
|
|
||||||
freq = kwargs.get("freq")
|
|
||||||
if word is None:
|
if word is None:
|
||||||
raise ValueError("word is not provided.")
|
raise ValueError("word is not provided.")
|
||||||
kwargs["word"] = word
|
|
||||||
if freq is not None:
|
if freq is not None:
|
||||||
check_uint32(freq)
|
check_uint32(freq)
|
||||||
kwargs["freq"] = freq
|
return method(self, *args, **kwargs)
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -227,13 +146,8 @@ def check_jieba_add_dict(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
user_dict = (list(args) + [None])[0]
|
parse_user_args(method, *args, **kwargs)
|
||||||
if "user_dict" in kwargs:
|
return method(self, *args, **kwargs)
|
||||||
user_dict = kwargs.get("user_dict")
|
|
||||||
if user_dict is None:
|
|
||||||
raise ValueError("user_dict is not provided.")
|
|
||||||
kwargs["user_dict"] = user_dict
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -244,69 +158,39 @@ def check_from_dataset(method):
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
|
||||||
dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
|
[_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
|
||||||
if "dataset" in kwargs:
|
**kwargs)
|
||||||
dataset = kwargs.get("dataset")
|
if columns is not None:
|
||||||
if "columns" in kwargs:
|
if not isinstance(columns, list):
|
||||||
columns = kwargs.get("columns")
|
columns = [columns]
|
||||||
if "freq_range" in kwargs:
|
col_names = ["col_{0}".format(i) for i in range(len(columns))]
|
||||||
freq_range = kwargs.get("freq_range")
|
type_check_list(columns, (str,), col_names)
|
||||||
if "top_k" in kwargs:
|
|
||||||
top_k = kwargs.get("top_k")
|
|
||||||
if "special_tokens" in kwargs:
|
|
||||||
special_tokens = kwargs.get("special_tokens")
|
|
||||||
if "special_first" in kwargs:
|
|
||||||
special_first = kwargs.get("special_first")
|
|
||||||
|
|
||||||
if columns is None:
|
if freq_range is not None:
|
||||||
columns = []
|
type_check(freq_range, (tuple,), "freq_range")
|
||||||
|
|
||||||
if not isinstance(columns, list):
|
if len(freq_range) != 2:
|
||||||
columns = [columns]
|
raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.")
|
||||||
|
|
||||||
for column in columns:
|
for num in freq_range:
|
||||||
if not isinstance(column, str):
|
if num is not None and (not isinstance(num, int)):
|
||||||
raise ValueError("columns need to be a list of strings.")
|
raise ValueError(
|
||||||
|
"freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
|
||||||
|
|
||||||
if freq_range is None:
|
if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
|
||||||
freq_range = (None, None)
|
if freq_range[0] > freq_range[1] or freq_range[0] < 0:
|
||||||
|
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
|
||||||
|
|
||||||
if not isinstance(freq_range, tuple) or len(freq_range) != 2:
|
type_check(top_k, (int, type(None)), "top_k")
|
||||||
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
|
|
||||||
|
|
||||||
for num in freq_range:
|
if isinstance(top_k, int):
|
||||||
if num is not None and (not isinstance(num, int)):
|
check_value(top_k, (0, INT32_MAX), "top_k")
|
||||||
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
|
type_check(special_first, (bool,), "special_first")
|
||||||
|
|
||||||
if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
|
if special_tokens is not None:
|
||||||
if freq_range[0] > freq_range[1] or freq_range[0] < 0:
|
check_unique_list_of_words(special_tokens, "special_tokens")
|
||||||
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
|
|
||||||
|
|
||||||
if top_k is not None and (not isinstance(top_k, int)):
|
return method(self, *args, **kwargs)
|
||||||
raise ValueError("top_k needs to be a positive integer.")
|
|
||||||
|
|
||||||
if isinstance(top_k, int) and top_k <= 0:
|
|
||||||
raise ValueError("top_k needs to be a positive integer.")
|
|
||||||
|
|
||||||
if special_first is None:
|
|
||||||
special_first = True
|
|
||||||
|
|
||||||
if special_tokens is None:
|
|
||||||
special_tokens = []
|
|
||||||
|
|
||||||
if not isinstance(special_first, bool):
|
|
||||||
raise ValueError("special_first needs to be a boolean value.")
|
|
||||||
|
|
||||||
check_unique_list_of_words(special_tokens, "special_tokens")
|
|
||||||
|
|
||||||
kwargs["dataset"] = dataset
|
|
||||||
kwargs["columns"] = columns
|
|
||||||
kwargs["freq_range"] = freq_range
|
|
||||||
kwargs["top_k"] = top_k
|
|
||||||
kwargs["special_tokens"] = special_tokens
|
|
||||||
kwargs["special_first"] = special_first
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -316,15 +200,7 @@ def check_ngram(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4]
|
[n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "n" in kwargs:
|
|
||||||
n = kwargs.get("n")
|
|
||||||
if "left_pad" in kwargs:
|
|
||||||
left_pad = kwargs.get("left_pad")
|
|
||||||
if "right_pad" in kwargs:
|
|
||||||
right_pad = kwargs.get("right_pad")
|
|
||||||
if "separator" in kwargs:
|
|
||||||
separator = kwargs.get("separator")
|
|
||||||
|
|
||||||
if isinstance(n, int):
|
if isinstance(n, int):
|
||||||
n = [n]
|
n = [n]
|
||||||
|
@ -332,15 +208,9 @@ def check_ngram(method):
|
||||||
if not (isinstance(n, list) and n != []):
|
if not (isinstance(n, list) and n != []):
|
||||||
raise ValueError("n needs to be a non-empty list of positive integers.")
|
raise ValueError("n needs to be a non-empty list of positive integers.")
|
||||||
|
|
||||||
for gram in n:
|
for i, gram in enumerate(n):
|
||||||
if not (isinstance(gram, int) and gram > 0):
|
type_check(gram, (int,), "gram[{0}]".format(i))
|
||||||
raise ValueError("n in ngram needs to be a positive number.")
|
check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
|
||||||
|
|
||||||
if left_pad is None:
|
|
||||||
left_pad = ("", 0)
|
|
||||||
|
|
||||||
if right_pad is None:
|
|
||||||
right_pad = ("", 0)
|
|
||||||
|
|
||||||
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
|
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
|
||||||
left_pad[1], int)):
|
left_pad[1], int)):
|
||||||
|
@ -353,11 +223,7 @@ def check_ngram(method):
|
||||||
if not (left_pad[1] >= 0 and right_pad[1] >= 0):
|
if not (left_pad[1] >= 0 and right_pad[1] >= 0):
|
||||||
raise ValueError("padding width need to be positive numbers.")
|
raise ValueError("padding width need to be positive numbers.")
|
||||||
|
|
||||||
if separator is None:
|
type_check(separator, (str,), "separator")
|
||||||
separator = " "
|
|
||||||
|
|
||||||
if not isinstance(separator, str):
|
|
||||||
raise ValueError("separator needs to be a string.")
|
|
||||||
|
|
||||||
kwargs["n"] = n
|
kwargs["n"] = n
|
||||||
kwargs["left_pad"] = left_pad
|
kwargs["left_pad"] = left_pad
|
||||||
|
@ -374,16 +240,8 @@ def check_pair_truncate(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
max_length = (list(args) + [None])[0]
|
parse_user_args(method, *args, **kwargs)
|
||||||
if "max_length" in kwargs:
|
return method(self, *args, **kwargs)
|
||||||
max_length = kwargs.get("max_length")
|
|
||||||
if max_length is None:
|
|
||||||
raise ValueError("max_length is not provided.")
|
|
||||||
|
|
||||||
check_pos_int64(max_length)
|
|
||||||
kwargs["max_length"] = max_length
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -393,22 +251,13 @@ def check_to_number(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
data_type = (list(args) + [None])[0]
|
[data_type], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "data_type" in kwargs:
|
type_check(data_type, (typing.Type,), "data_type")
|
||||||
data_type = kwargs.get("data_type")
|
|
||||||
|
|
||||||
if data_type is None:
|
|
||||||
raise ValueError("data_type is a mandatory parameter but was not provided.")
|
|
||||||
|
|
||||||
if not isinstance(data_type, typing.Type):
|
|
||||||
raise TypeError("data_type is not a MindSpore data type.")
|
|
||||||
|
|
||||||
if data_type not in mstype.number_type:
|
if data_type not in mstype.number_type:
|
||||||
raise TypeError("data_type is not numeric data type.")
|
raise TypeError("data_type is not numeric data type.")
|
||||||
|
|
||||||
kwargs["data_type"] = data_type
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -418,18 +267,11 @@ def check_python_tokenizer(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
tokenizer = (list(args) + [None])[0]
|
[tokenizer], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "tokenizer" in kwargs:
|
|
||||||
tokenizer = kwargs.get("tokenizer")
|
|
||||||
|
|
||||||
if tokenizer is None:
|
|
||||||
raise ValueError("tokenizer is a mandatory parameter.")
|
|
||||||
|
|
||||||
if not callable(tokenizer):
|
if not callable(tokenizer):
|
||||||
raise TypeError("tokenizer is not a callable python function")
|
raise TypeError("tokenizer is not a callable python function")
|
||||||
|
|
||||||
kwargs["tokenizer"] = tokenizer
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
|
@ -18,6 +18,7 @@ from functools import wraps
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore._c_expression import typing
|
from mindspore._c_expression import typing
|
||||||
|
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive
|
||||||
|
|
||||||
# POS_INT_MIN is used to limit values from starting from 0
|
# POS_INT_MIN is used to limit values from starting from 0
|
||||||
POS_INT_MIN = 1
|
POS_INT_MIN = 1
|
||||||
|
@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992
|
||||||
DOUBLE_MIN_INTEGER = -9007199254740992
|
DOUBLE_MIN_INTEGER = -9007199254740992
|
||||||
|
|
||||||
|
|
||||||
def check_type(value, valid_type):
|
def check_fill_value(method):
|
||||||
if not isinstance(value, valid_type):
|
"""Wrapper method to check the parameters of fill_value."""
|
||||||
raise ValueError("Wrong input type")
|
|
||||||
|
|
||||||
|
|
||||||
def check_value(value, valid_range):
|
|
||||||
if value < valid_range[0] or value > valid_range[1]:
|
|
||||||
raise ValueError("Input is not within the required range")
|
|
||||||
|
|
||||||
|
|
||||||
def check_range(values, valid_range):
|
|
||||||
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
|
|
||||||
raise ValueError("Input range is not valid")
|
|
||||||
|
|
||||||
|
|
||||||
def check_positive(value):
|
|
||||||
if value <= 0:
|
|
||||||
raise ValueError("Input must greater than 0")
|
|
||||||
|
|
||||||
|
|
||||||
def check_positive_float(value, valid_max=None):
|
|
||||||
if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max):
|
|
||||||
raise ValueError("Input need to be a valid positive float.")
|
|
||||||
|
|
||||||
|
|
||||||
def check_bool(value):
|
|
||||||
if not isinstance(value, bool):
|
|
||||||
raise ValueError("Value needs to be a boolean.")
|
|
||||||
|
|
||||||
|
|
||||||
def check_2tuple(value):
|
|
||||||
if not (isinstance(value, tuple) and len(value) == 2):
|
|
||||||
raise ValueError("Value needs to be a 2-tuple.")
|
|
||||||
|
|
||||||
|
|
||||||
def check_list(value):
|
|
||||||
if not isinstance(value, list):
|
|
||||||
raise ValueError("The input needs to be a list.")
|
|
||||||
|
|
||||||
|
|
||||||
def check_uint8(value):
|
|
||||||
if not isinstance(value, int):
|
|
||||||
raise ValueError("The input needs to be a integer")
|
|
||||||
check_value(value, [UINT8_MIN, UINT8_MAX])
|
|
||||||
|
|
||||||
|
|
||||||
def check_uint32(value):
|
|
||||||
if not isinstance(value, int):
|
|
||||||
raise ValueError("The input needs to be a integer")
|
|
||||||
check_value(value, [UINT32_MIN, UINT32_MAX])
|
|
||||||
|
|
||||||
|
|
||||||
def check_pos_int32(value):
|
|
||||||
"""Checks for int values starting from 1"""
|
|
||||||
if not isinstance(value, int):
|
|
||||||
raise ValueError("The input needs to be a integer")
|
|
||||||
check_value(value, [POS_INT_MIN, INT32_MAX])
|
|
||||||
|
|
||||||
|
|
||||||
def check_uint64(value):
|
|
||||||
if not isinstance(value, int):
|
|
||||||
raise ValueError("The input needs to be a integer")
|
|
||||||
check_value(value, [UINT64_MIN, UINT64_MAX])
|
|
||||||
|
|
||||||
|
|
||||||
def check_pos_int64(value):
|
|
||||||
if not isinstance(value, int):
|
|
||||||
raise ValueError("The input needs to be a integer")
|
|
||||||
check_value(value, [UINT64_MIN, INT64_MAX])
|
|
||||||
|
|
||||||
|
|
||||||
def check_pos_float32(value):
|
|
||||||
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER])
|
|
||||||
|
|
||||||
|
|
||||||
def check_pos_float64(value):
|
|
||||||
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER])
|
|
||||||
|
|
||||||
|
|
||||||
def check_one_hot_op(method):
|
|
||||||
"""Wrapper method to check the parameters of one hot op."""
|
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
args = (list(args) + 2 * [None])[:2]
|
[fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||||
num_classes, smoothing_rate = args
|
type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
|
||||||
if "num_classes" in kwargs:
|
|
||||||
num_classes = kwargs.get("num_classes")
|
return method(self, *args, **kwargs)
|
||||||
if "smoothing_rate" in kwargs:
|
|
||||||
smoothing_rate = kwargs.get("smoothing_rate")
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_one_hot_op(method):
|
||||||
|
"""Wrapper method to check the parameters of one_hot_op."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
[num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
|
type_check(num_classes, (int,), "num_classes")
|
||||||
|
check_positive(num_classes)
|
||||||
|
|
||||||
if num_classes is None:
|
|
||||||
raise ValueError("num_classes")
|
|
||||||
check_pos_int32(num_classes)
|
|
||||||
kwargs["num_classes"] = num_classes
|
|
||||||
if smoothing_rate is not None:
|
if smoothing_rate is not None:
|
||||||
check_value(smoothing_rate, [0., 1.])
|
check_value(smoothing_rate, [0., 1.], "smoothing_rate")
|
||||||
kwargs["smoothing_rate"] = smoothing_rate
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -146,35 +74,12 @@ def check_num_classes(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
num_classes = (list(args) + [None])[0]
|
[num_classes], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "num_classes" in kwargs:
|
|
||||||
num_classes = kwargs.get("num_classes")
|
|
||||||
if num_classes is None:
|
|
||||||
raise ValueError("num_classes is not provided.")
|
|
||||||
|
|
||||||
check_pos_int32(num_classes)
|
type_check(num_classes, (int,), "num_classes")
|
||||||
kwargs["num_classes"] = num_classes
|
check_positive(num_classes)
|
||||||
|
|
||||||
return method(self, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
|
||||||
|
|
||||||
|
|
||||||
def check_fill_value(method):
|
|
||||||
"""Wrapper method to check the parameters of fill value."""
|
|
||||||
|
|
||||||
@wraps(method)
|
|
||||||
def new_method(self, *args, **kwargs):
|
|
||||||
fill_value = (list(args) + [None])[0]
|
|
||||||
if "fill_value" in kwargs:
|
|
||||||
fill_value = kwargs.get("fill_value")
|
|
||||||
if fill_value is None:
|
|
||||||
raise ValueError("fill_value is not provided.")
|
|
||||||
if not isinstance(fill_value, (str, float, bool, int, bytes)):
|
|
||||||
raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
|
|
||||||
kwargs["fill_value"] = fill_value
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -184,17 +89,11 @@ def check_de_type(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
data_type = (list(args) + [None])[0]
|
[data_type], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "data_type" in kwargs:
|
|
||||||
data_type = kwargs.get("data_type")
|
|
||||||
|
|
||||||
if data_type is None:
|
type_check(data_type, (typing.Type,), "data_type")
|
||||||
raise ValueError("data_type is not provided.")
|
|
||||||
if not isinstance(data_type, typing.Type):
|
|
||||||
raise TypeError("data_type is not a MindSpore data type.")
|
|
||||||
kwargs["data_type"] = data_type
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -204,13 +103,11 @@ def check_slice_op(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args):
|
def new_method(self, *args):
|
||||||
for i, arg in enumerate(args):
|
for _, arg in enumerate(args):
|
||||||
if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)):
|
type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg")
|
||||||
raise TypeError("Indexing of dim " + str(i) + "is not of valid type")
|
|
||||||
if isinstance(arg, list):
|
if isinstance(arg, list):
|
||||||
for a in arg:
|
for a in arg:
|
||||||
if not isinstance(a, int):
|
type_check(a, (int,), "a")
|
||||||
raise TypeError("Index " + a + " is not an int")
|
|
||||||
return method(self, *args)
|
return method(self, *args)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
@ -221,36 +118,14 @@ def check_mask_op(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
operator, constant, dtype = (list(args) + 3 * [None])[:3]
|
[operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
|
||||||
if "operator" in kwargs:
|
|
||||||
operator = kwargs.get("operator")
|
|
||||||
if "constant" in kwargs:
|
|
||||||
constant = kwargs.get("constant")
|
|
||||||
if "dtype" in kwargs:
|
|
||||||
dtype = kwargs.get("dtype")
|
|
||||||
|
|
||||||
if operator is None:
|
|
||||||
raise ValueError("operator is not provided.")
|
|
||||||
|
|
||||||
if constant is None:
|
|
||||||
raise ValueError("constant is not provided.")
|
|
||||||
|
|
||||||
from .c_transforms import Relational
|
from .c_transforms import Relational
|
||||||
if not isinstance(operator, Relational):
|
type_check(operator, (Relational,), "operator")
|
||||||
raise TypeError("operator is not a Relational operator enum.")
|
type_check(constant, (str, float, bool, int, bytes), "constant")
|
||||||
|
type_check(dtype, (typing.Type,), "dtype")
|
||||||
|
|
||||||
if not isinstance(constant, (str, float, bool, int, bytes)):
|
return method(self, *args, **kwargs)
|
||||||
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
|
|
||||||
|
|
||||||
if dtype is not None:
|
|
||||||
if not isinstance(dtype, typing.Type):
|
|
||||||
raise TypeError("dtype is not a MindSpore data type.")
|
|
||||||
kwargs["dtype"] = dtype
|
|
||||||
|
|
||||||
kwargs["operator"] = operator
|
|
||||||
kwargs["constant"] = constant
|
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -260,22 +135,12 @@ def check_pad_end(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
pad_shape, pad_value = (list(args) + 2 * [None])[:2]
|
|
||||||
if "pad_shape" in kwargs:
|
|
||||||
pad_shape = kwargs.get("pad_shape")
|
|
||||||
if "pad_value" in kwargs:
|
|
||||||
pad_value = kwargs.get("pad_value")
|
|
||||||
|
|
||||||
if pad_shape is None:
|
[pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
|
||||||
raise ValueError("pad_shape is not provided.")
|
|
||||||
|
|
||||||
if pad_value is not None:
|
if pad_value is not None:
|
||||||
if not isinstance(pad_value, (str, float, bool, int, bytes)):
|
type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
|
||||||
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
|
type_check(pad_shape, (list,), "pad_end")
|
||||||
kwargs["pad_value"] = pad_value
|
|
||||||
|
|
||||||
if not isinstance(pad_shape, list):
|
|
||||||
raise TypeError("pad_shape must be a list")
|
|
||||||
|
|
||||||
for dim in pad_shape:
|
for dim in pad_shape:
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
|
@ -284,9 +149,7 @@ def check_pad_end(method):
|
||||||
else:
|
else:
|
||||||
raise TypeError("a value in the list is not an integer.")
|
raise TypeError("a value in the list is not an integer.")
|
||||||
|
|
||||||
kwargs["pad_shape"] = pad_shape
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return method(self, **kwargs)
|
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
@ -296,31 +159,24 @@ def check_concat_type(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
axis, prepend, append = (list(args) + 3 * [None])[:3]
|
|
||||||
if "prepend" in kwargs:
|
[axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
|
||||||
prepend = kwargs.get("prepend")
|
|
||||||
if "append" in kwargs:
|
|
||||||
append = kwargs.get("append")
|
|
||||||
if "axis" in kwargs:
|
|
||||||
axis = kwargs.get("axis")
|
|
||||||
|
|
||||||
if axis is not None:
|
if axis is not None:
|
||||||
if not isinstance(axis, int):
|
type_check(axis, (int,), "axis")
|
||||||
raise TypeError("axis type is not valid, must be an integer.")
|
|
||||||
if axis not in (0, -1):
|
if axis not in (0, -1):
|
||||||
raise ValueError("only 1D concatenation supported.")
|
raise ValueError("only 1D concatenation supported.")
|
||||||
kwargs["axis"] = axis
|
|
||||||
|
|
||||||
if prepend is not None:
|
if prepend is not None:
|
||||||
if not isinstance(prepend, (type(None), np.ndarray)):
|
type_check(prepend, (np.ndarray,), "prepend")
|
||||||
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
|
if len(prepend.shape) != 1:
|
||||||
kwargs["prepend"] = prepend
|
raise ValueError("can only prepend 1D arrays.")
|
||||||
|
|
||||||
if append is not None:
|
if append is not None:
|
||||||
if not isinstance(append, (type(None), np.ndarray)):
|
type_check(append, (np.ndarray,), "append")
|
||||||
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
|
if len(append.shape) != 1:
|
||||||
kwargs["append"] = append
|
raise ValueError("can only append 1D arrays.")
|
||||||
|
|
||||||
return method(self, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
|
@ -40,12 +40,14 @@ Examples:
|
||||||
>>> dataset = dataset.map(input_columns="image", operations=transforms_list)
|
>>> dataset = dataset.map(input_columns="image", operations=transforms_list)
|
||||||
>>> dataset = dataset.map(input_columns="label", operations=onehot_op)
|
>>> dataset = dataset.map(input_columns="label", operations=onehot_op)
|
||||||
"""
|
"""
|
||||||
|
import numbers
|
||||||
import mindspore._c_dataengine as cde
|
import mindspore._c_dataengine as cde
|
||||||
|
|
||||||
from .utils import Inter, Border
|
from .utils import Inter, Border
|
||||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||||
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
|
||||||
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp
|
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
|
||||||
|
FLOAT_MAX_INTEGER
|
||||||
|
|
||||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||||
|
@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
|
||||||
Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC}
|
Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_padding(padding):
|
||||||
|
if isinstance(padding, numbers.Number):
|
||||||
|
padding = [padding] * 4
|
||||||
|
if len(padding) == 2:
|
||||||
|
left = right = padding[0]
|
||||||
|
top = bottom = padding[1]
|
||||||
|
padding = (left, top, right, bottom,)
|
||||||
|
if isinstance(padding, list):
|
||||||
|
padding = tuple(padding)
|
||||||
|
return padding
|
||||||
|
|
||||||
|
|
||||||
class Decode(cde.DecodeOp):
|
class Decode(cde.DecodeOp):
|
||||||
"""
|
"""
|
||||||
Decode the input image in RGB mode.
|
Decode the input image in RGB mode.
|
||||||
|
@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp):
|
||||||
|
|
||||||
@check_random_crop
|
@check_random_crop
|
||||||
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
|
if padding is None:
|
||||||
|
padding = (0, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
padding = parse_padding(padding)
|
||||||
|
if isinstance(fill_value, int): # temporary fix
|
||||||
|
fill_value = tuple([fill_value] * 3)
|
||||||
|
border_type = DE_C_BORDER_TYPE[padding_mode]
|
||||||
|
|
||||||
self.size = size
|
self.size = size
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.pad_if_needed = pad_if_needed
|
self.pad_if_needed = pad_if_needed
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
self.padding_mode = padding_mode.value
|
self.padding_mode = padding_mode.value
|
||||||
if padding is None:
|
|
||||||
padding = (0, 0, 0, 0)
|
|
||||||
if isinstance(fill_value, int): # temporary fix
|
|
||||||
fill_value = tuple([fill_value] * 3)
|
|
||||||
border_type = DE_C_BORDER_TYPE[padding_mode]
|
|
||||||
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
|
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp):
|
||||||
|
|
||||||
@check_random_crop
|
@check_random_crop
|
||||||
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
|
if padding is None:
|
||||||
|
padding = (0, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
padding = parse_padding(padding)
|
||||||
|
|
||||||
|
if isinstance(fill_value, int): # temporary fix
|
||||||
|
fill_value = tuple([fill_value] * 3)
|
||||||
|
border_type = DE_C_BORDER_TYPE[padding_mode]
|
||||||
|
|
||||||
self.size = size
|
self.size = size
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.pad_if_needed = pad_if_needed
|
self.pad_if_needed = pad_if_needed
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
self.padding_mode = padding_mode.value
|
self.padding_mode = padding_mode.value
|
||||||
if padding is None:
|
|
||||||
padding = (0, 0, 0, 0)
|
|
||||||
if isinstance(fill_value, int): # temporary fix
|
|
||||||
fill_value = tuple([fill_value] * 3)
|
|
||||||
border_type = DE_C_BORDER_TYPE[padding_mode]
|
|
||||||
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
|
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -292,6 +319,8 @@ class Resize(cde.ResizeOp):
|
||||||
|
|
||||||
@check_resize_interpolation
|
@check_resize_interpolation
|
||||||
def __init__(self, size, interpolation=Inter.LINEAR):
|
def __init__(self, size, interpolation=Inter.LINEAR):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.interpolation = interpolation
|
self.interpolation = interpolation
|
||||||
interpoltn = DE_C_INTER_MODE[interpolation]
|
interpoltn = DE_C_INTER_MODE[interpolation]
|
||||||
|
@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
|
||||||
@check_random_resize_crop
|
@check_random_resize_crop
|
||||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||||
interpolation=Inter.BILINEAR, max_attempts=10):
|
interpolation=Inter.BILINEAR, max_attempts=10):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.ratio = ratio
|
self.ratio = ratio
|
||||||
|
@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp):
|
||||||
@check_random_resize_crop
|
@check_random_resize_crop
|
||||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||||
interpolation=Inter.BILINEAR, max_attempts=10):
|
interpolation=Inter.BILINEAR, max_attempts=10):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.ratio = ratio
|
self.ratio = ratio
|
||||||
|
@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp):
|
||||||
|
|
||||||
@check_crop
|
@check_crop
|
||||||
def __init__(self, size):
|
def __init__(self, size):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
self.size = size
|
self.size = size
|
||||||
super().__init__(*size)
|
super().__init__(*size)
|
||||||
|
|
||||||
|
@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp):
|
||||||
|
|
||||||
@check_random_color_adjust
|
@check_random_color_adjust
|
||||||
def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)):
|
def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)):
|
||||||
|
brightness = self.expand_values(brightness)
|
||||||
|
contrast = self.expand_values(contrast)
|
||||||
|
saturation = self.expand_values(saturation)
|
||||||
|
hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False)
|
||||||
|
|
||||||
self.brightness = brightness
|
self.brightness = brightness
|
||||||
self.contrast = contrast
|
self.contrast = contrast
|
||||||
self.saturation = saturation
|
self.saturation = saturation
|
||||||
self.hue = hue
|
self.hue = hue
|
||||||
|
|
||||||
super().__init__(*brightness, *contrast, *saturation, *hue)
|
super().__init__(*brightness, *contrast, *saturation, *hue)
|
||||||
|
|
||||||
|
def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
|
||||||
|
if isinstance(value, numbers.Number):
|
||||||
|
value = [center - value, center + value]
|
||||||
|
if non_negative:
|
||||||
|
value[0] = max(0, value[0])
|
||||||
|
check_range(value, bound)
|
||||||
|
return (value[0], value[1])
|
||||||
|
|
||||||
|
|
||||||
class RandomRotation(cde.RandomRotationOp):
|
class RandomRotation(cde.RandomRotationOp):
|
||||||
"""
|
"""
|
||||||
|
@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp):
|
||||||
self.expand = expand
|
self.expand = expand
|
||||||
self.center = center
|
self.center = center
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
|
if isinstance(degrees, numbers.Number):
|
||||||
|
degrees = (-degrees, degrees)
|
||||||
if center is None:
|
if center is None:
|
||||||
center = (-1, -1)
|
center = (-1, -1)
|
||||||
if isinstance(fill_value, int): # temporary fix
|
if isinstance(fill_value, int): # temporary fix
|
||||||
|
@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp):
|
||||||
@check_random_resize_crop
|
@check_random_resize_crop
|
||||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||||
interpolation=Inter.BILINEAR, max_attempts=10):
|
interpolation=Inter.BILINEAR, max_attempts=10):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.ratio = ratio
|
self.ratio = ratio
|
||||||
|
@ -623,12 +676,14 @@ class Pad(cde.PadOp):
|
||||||
|
|
||||||
@check_pad
|
@check_pad
|
||||||
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
|
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
|
||||||
self.padding = padding
|
padding = parse_padding(padding)
|
||||||
self.fill_value = fill_value
|
|
||||||
self.padding_mode = padding_mode
|
|
||||||
if isinstance(fill_value, int): # temporary fix
|
if isinstance(fill_value, int): # temporary fix
|
||||||
fill_value = tuple([fill_value] * 3)
|
fill_value = tuple([fill_value] * 3)
|
||||||
padding_mode = DE_C_BORDER_TYPE[padding_mode]
|
padding_mode = DE_C_BORDER_TYPE[padding_mode]
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.fill_value = fill_value
|
||||||
|
self.padding_mode = padding_mode
|
||||||
super().__init__(*padding, padding_mode, *fill_value)
|
super().__init__(*padding, padding_mode, *fill_value)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from . import py_transforms_util as util
|
from . import py_transforms_util as util
|
||||||
|
from .c_transforms import parse_padding
|
||||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||||
check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||||
check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \
|
check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \
|
||||||
|
@ -295,6 +296,10 @@ class RandomCrop:
|
||||||
|
|
||||||
@check_random_crop
|
@check_random_crop
|
||||||
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
|
||||||
|
if padding is None:
|
||||||
|
padding = (0, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
padding = parse_padding(padding)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.pad_if_needed = pad_if_needed
|
self.pad_if_needed = pad_if_needed
|
||||||
|
@ -753,6 +758,8 @@ class TenCrop:
|
||||||
|
|
||||||
@check_ten_crop
|
@check_ten_crop
|
||||||
def __init__(self, size, use_vertical_flip=False):
|
def __init__(self, size, use_vertical_flip=False):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size = (size, size)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.use_vertical_flip = use_vertical_flip
|
self.use_vertical_flip = use_vertical_flip
|
||||||
|
|
||||||
|
@ -877,6 +884,8 @@ class Pad:
|
||||||
|
|
||||||
@check_pad
|
@check_pad
|
||||||
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
|
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
|
||||||
|
parse_padding(padding)
|
||||||
|
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
|
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
|
||||||
|
@ -1129,56 +1138,23 @@ class RandomAffine:
|
||||||
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0):
|
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0):
|
||||||
# Parameter checking
|
# Parameter checking
|
||||||
# rotation
|
# rotation
|
||||||
if isinstance(degrees, numbers.Number):
|
|
||||||
if degrees < 0:
|
|
||||||
raise ValueError("If degrees is a single number, it must be positive.")
|
|
||||||
self.degrees = (-degrees, degrees)
|
|
||||||
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
|
|
||||||
self.degrees = degrees
|
|
||||||
else:
|
|
||||||
raise TypeError("If degrees is a list or tuple, it must be of length 2.")
|
|
||||||
|
|
||||||
# translation
|
|
||||||
if translate is not None:
|
|
||||||
if isinstance(translate, (tuple, list)) and len(translate) == 2:
|
|
||||||
for t in translate:
|
|
||||||
if t < 0.0 or t > 1.0:
|
|
||||||
raise ValueError("translation values should be between 0 and 1")
|
|
||||||
else:
|
|
||||||
raise TypeError("translate should be a list or tuple of length 2.")
|
|
||||||
self.translate = translate
|
|
||||||
|
|
||||||
# scale
|
|
||||||
if scale is not None:
|
|
||||||
if isinstance(scale, (tuple, list)) and len(scale) == 2:
|
|
||||||
for s in scale:
|
|
||||||
if s <= 0:
|
|
||||||
raise ValueError("scale values should be positive")
|
|
||||||
else:
|
|
||||||
raise TypeError("scale should be a list or tuple of length 2.")
|
|
||||||
self.scale_ranges = scale
|
|
||||||
|
|
||||||
# shear
|
|
||||||
if shear is not None:
|
if shear is not None:
|
||||||
if isinstance(shear, numbers.Number):
|
if isinstance(shear, numbers.Number):
|
||||||
if shear < 0:
|
shear = (-1 * shear, shear)
|
||||||
raise ValueError("If shear is a single number, it must be positive.")
|
|
||||||
self.shear = (-1 * shear, shear)
|
|
||||||
elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4):
|
|
||||||
# X-Axis shear with [min, max]
|
|
||||||
if len(shear) == 2:
|
|
||||||
self.shear = [shear[0], shear[1], 0., 0.]
|
|
||||||
elif len(shear) == 4:
|
|
||||||
self.shear = [s for s in shear]
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.")
|
if len(shear) == 2:
|
||||||
else:
|
shear = [shear[0], shear[1], 0., 0.]
|
||||||
self.shear = shear
|
elif len(shear) == 4:
|
||||||
|
shear = [s for s in shear]
|
||||||
|
|
||||||
# resample
|
if isinstance(degrees, numbers.Number):
|
||||||
|
degrees = (-degrees, degrees)
|
||||||
|
|
||||||
|
self.degrees = degrees
|
||||||
|
self.translate = translate
|
||||||
|
self.scale_ranges = scale
|
||||||
|
self.shear = shear
|
||||||
self.resample = DE_PY_INTER_MODE[resample]
|
self.resample = DE_PY_INTER_MODE[resample]
|
||||||
|
|
||||||
# fill_value
|
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -15,13 +15,15 @@
|
||||||
"""
|
"""
|
||||||
Testing the bounding box augment op in DE
|
Testing the bounding box augment op in DE
|
||||||
"""
|
"""
|
||||||
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
|
|
||||||
config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.log as logger
|
import mindspore.log as logger
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||||
|
|
||||||
|
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
|
||||||
|
config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
|
||||||
|
|
||||||
GENERATE_GOLDEN = False
|
GENERATE_GOLDEN = False
|
||||||
|
|
||||||
# updated VOC dataset with correct annotations
|
# updated VOC dataset with correct annotations
|
||||||
|
@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c():
|
||||||
operations=[test_op]) # Add column for "annotation"
|
operations=[test_op]) # Add column for "annotation"
|
||||||
except ValueError as error:
|
except ValueError as error:
|
||||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||||
assert "Input is not" in str(error)
|
assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)
|
||||||
|
|
||||||
|
|
||||||
def test_bounding_box_augment_invalid_bounds_c():
|
def test_bounding_box_augment_invalid_bounds_c():
|
||||||
|
|
|
@ -17,6 +17,7 @@ import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
|
|
||||||
|
|
||||||
# generates 1 column [0], [0, 1], ..., [0, ..., n-1]
|
# generates 1 column [0], [0, 1], ..., [0, ..., n-1]
|
||||||
def generate_sequential(n):
|
def generate_sequential(n):
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
|
@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input():
|
||||||
with pytest.raises(TypeError) as info:
|
with pytest.raises(TypeError) as info:
|
||||||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||||
None, None, invalid_type_pad_to_bucket_boundary)
|
None, None, invalid_type_pad_to_bucket_boundary)
|
||||||
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
|
assert "Argument pad_to_bucket_boundary with value \"\" is not of type (<class \'bool\'>,)." in str(info.value)
|
||||||
|
|
||||||
with pytest.raises(TypeError) as info:
|
with pytest.raises(TypeError) as info:
|
||||||
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
||||||
None, None, False, invalid_type_drop_remainder)
|
None, None, False, invalid_type_drop_remainder)
|
||||||
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
|
assert "Argument drop_remainder with value \"\" is not of type (<class 'bool'>,)." in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_bucket_batch_multi_bucket_no_padding():
|
def test_bucket_batch_multi_bucket_no_padding():
|
||||||
|
@ -272,7 +273,6 @@ def test_bucket_batch_default_pad():
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]]
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]]
|
||||||
|
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
for data in dataset.create_dict_iterator():
|
for data in dataset.create_dict_iterator():
|
||||||
output.append(data["col1"].tolist())
|
output.append(data["col1"].tolist())
|
||||||
|
|
|
@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis():
|
||||||
|
|
||||||
|
|
||||||
def test_concatenate_op_incorrect_input_dim():
|
def test_concatenate_op_incorrect_input_dim():
|
||||||
def gen():
|
|
||||||
yield (np.array(["ss", "ad"], dtype='S'),)
|
|
||||||
|
|
||||||
prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
|
prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
|
||||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
|
||||||
concatenate_op = data_trans.Concatenate(0, prepend_tensor)
|
|
||||||
|
|
||||||
data = data.map(input_columns=["col"], operations=concatenate_op)
|
with pytest.raises(ValueError) as error_info:
|
||||||
with pytest.raises(RuntimeError) as error_info:
|
data_trans.Concatenate(0, prepend_tensor)
|
||||||
for _ in data:
|
assert "can only prepend 1D arrays." in repr(error_info.value)
|
||||||
pass
|
|
||||||
assert "Only 1D tensors supported" in repr(error_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -28,9 +28,9 @@ def test_exception_01():
|
||||||
"""
|
"""
|
||||||
logger.info("test_exception_01")
|
logger.info("test_exception_01")
|
||||||
data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
|
data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(TypeError) as info:
|
||||||
data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))
|
data.map(input_columns=["image"], operations=vision.Resize(100, 100))
|
||||||
assert "Invalid interpolation mode." in str(info.value)
|
assert "Argument interpolation with value 100 is not of type (<enum 'Inter'>,)" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_exception_02():
|
def test_exception_02():
|
||||||
|
@ -40,8 +40,8 @@ def test_exception_02():
|
||||||
logger.info("test_exception_02")
|
logger.info("test_exception_02")
|
||||||
num_samples = -1
|
num_samples = -1
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(ValueError) as info:
|
||||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||||
assert "num_samples cannot be less than 0" in str(info.value)
|
assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value)
|
||||||
|
|
||||||
num_samples = 1
|
num_samples = 1
|
||||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||||
|
|
|
@ -23,7 +23,8 @@ import mindspore.dataset.text as text
|
||||||
def test_demo_basic_from_dataset():
|
def test_demo_basic_from_dataset():
|
||||||
""" this is a tutorial on how from_dataset should be used in a normal use case"""
|
""" this is a tutorial on how from_dataset should be used in a normal use case"""
|
||||||
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
||||||
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
|
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None,
|
||||||
|
special_tokens=["<pad>", "<unk>"],
|
||||||
special_first=True)
|
special_first=True)
|
||||||
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
||||||
res = []
|
res = []
|
||||||
|
@ -127,15 +128,16 @@ def test_from_dataset_exceptions():
|
||||||
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
||||||
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
|
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
|
||||||
assert isinstance(vocab.text.Vocab)
|
assert isinstance(vocab.text.Vocab)
|
||||||
except ValueError as e:
|
except (TypeError, ValueError, RuntimeError) as e:
|
||||||
assert s in str(e), str(e)
|
assert s in str(e), str(e)
|
||||||
|
|
||||||
test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
|
test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
|
||||||
test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
|
test_config("text", (2, 3), 1.2345,
|
||||||
test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
|
"Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
|
||||||
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
|
test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
|
||||||
test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
|
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
|
||||||
test_config([123], (2, 3), 0, "columns need to be a list of strings")
|
test_config("text", (2, 3), 0, "top_k needs to be positive number")
|
||||||
|
test_config([123], (2, 3), 0, "top_k needs to be positive number")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False):
|
||||||
if plot:
|
if plot:
|
||||||
visualize_list(image, image_transformed)
|
visualize_list(image, image_transformed)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transformation_md5():
|
def test_linear_transformation_md5():
|
||||||
"""
|
"""
|
||||||
Test LinearTransformation op: valid params (transformation_matrix, mean_vector)
|
Test LinearTransformation op: valid params (transformation_matrix, mean_vector)
|
||||||
|
@ -102,6 +103,7 @@ def test_linear_transformation_md5():
|
||||||
filename = "linear_transformation_01_result.npz"
|
filename = "linear_transformation_01_result.npz"
|
||||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transformation_exception_01():
|
def test_linear_transformation_exception_01():
|
||||||
"""
|
"""
|
||||||
Test LinearTransformation op: transformation_matrix is not provided
|
Test LinearTransformation op: transformation_matrix is not provided
|
||||||
|
@ -126,9 +128,10 @@ def test_linear_transformation_exception_01():
|
||||||
]
|
]
|
||||||
transform = py_vision.ComposeOp(transforms)
|
transform = py_vision.ComposeOp(transforms)
|
||||||
data1 = data1.map(input_columns=["image"], operations=transform())
|
data1 = data1.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except TypeError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "not provided" in str(e)
|
assert "Argument transformation_matrix with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transformation_exception_02():
|
def test_linear_transformation_exception_02():
|
||||||
"""
|
"""
|
||||||
|
@ -154,9 +157,10 @@ def test_linear_transformation_exception_02():
|
||||||
]
|
]
|
||||||
transform = py_vision.ComposeOp(transforms)
|
transform = py_vision.ComposeOp(transforms)
|
||||||
data1 = data1.map(input_columns=["image"], operations=transform())
|
data1 = data1.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except TypeError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "not provided" in str(e)
|
assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transformation_exception_03():
|
def test_linear_transformation_exception_03():
|
||||||
"""
|
"""
|
||||||
|
@ -187,6 +191,7 @@ def test_linear_transformation_exception_03():
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "square matrix" in str(e)
|
assert "square matrix" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transformation_exception_04():
|
def test_linear_transformation_exception_04():
|
||||||
"""
|
"""
|
||||||
Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
|
Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
|
||||||
|
@ -199,7 +204,7 @@ def test_linear_transformation_exception_04():
|
||||||
weight = 50
|
weight = 50
|
||||||
dim = 3 * height * weight
|
dim = 3 * height * weight
|
||||||
transformation_matrix = np.ones([dim, dim])
|
transformation_matrix = np.ones([dim, dim])
|
||||||
mean_vector = np.zeros(dim-1)
|
mean_vector = np.zeros(dim - 1)
|
||||||
|
|
||||||
# Generate dataset
|
# Generate dataset
|
||||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||||
|
@ -216,6 +221,7 @@ def test_linear_transformation_exception_04():
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "should match" in str(e)
|
assert "should match" in str(e)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_linear_transformation_op(plot=True)
|
test_linear_transformation_op(plot=True)
|
||||||
test_linear_transformation_md5()
|
test_linear_transformation_md5()
|
||||||
|
|
|
@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards():
|
||||||
create_cv_mindrecord(1)
|
create_cv_mindrecord(1)
|
||||||
columns_list = ["data", "label"]
|
columns_list = ["data", "label"]
|
||||||
num_readers = 4
|
num_readers = 4
|
||||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
with pytest.raises(Exception) as error_info:
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
|
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in data_set.create_dict_iterator():
|
for _ in data_set.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
|
||||||
|
|
||||||
os.remove(CV_FILE_NAME)
|
os.remove(CV_FILE_NAME)
|
||||||
os.remove("{}.db".format(CV_FILE_NAME))
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
|
||||||
|
|
||||||
def test_minddataset_invalidate_shard_id():
|
def test_minddataset_invalidate_shard_id():
|
||||||
create_cv_mindrecord(1)
|
create_cv_mindrecord(1)
|
||||||
columns_list = ["data", "label"]
|
columns_list = ["data", "label"]
|
||||||
num_readers = 4
|
num_readers = 4
|
||||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
with pytest.raises(Exception) as error_info:
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
|
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in data_set.create_dict_iterator():
|
for _ in data_set.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
|
||||||
os.remove(CV_FILE_NAME)
|
os.remove(CV_FILE_NAME)
|
||||||
os.remove("{}.db".format(CV_FILE_NAME))
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
|
||||||
|
@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard():
|
||||||
create_cv_mindrecord(1)
|
create_cv_mindrecord(1)
|
||||||
columns_list = ["data", "label"]
|
columns_list = ["data", "label"]
|
||||||
num_readers = 4
|
num_readers = 4
|
||||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
with pytest.raises(Exception) as error_info:
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
|
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in data_set.create_dict_iterator():
|
for _ in data_set.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="shard_id is invalid, "):
|
with pytest.raises(Exception) as error_info:
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
|
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in data_set.create_dict_iterator():
|
for _ in data_set.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
|
||||||
|
|
||||||
os.remove(CV_FILE_NAME)
|
os.remove(CV_FILE_NAME)
|
||||||
os.remove("{}.db".format(CV_FILE_NAME))
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
"""
|
"""
|
||||||
Testing Ngram in mindspore.dataset
|
Testing Ngram in mindspore.dataset
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.text as text
|
import mindspore.dataset.text as text
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_ngrams():
|
def test_multiple_ngrams():
|
||||||
|
@ -61,7 +61,7 @@ def test_simple_ngram():
|
||||||
yield (np.array(line.split(" "), dtype='S'),)
|
yield (np.array(line.split(" "), dtype='S'),)
|
||||||
|
|
||||||
dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
|
dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
|
||||||
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None))
|
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" "))
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
for data in dataset.create_dict_iterator():
|
for data in dataset.create_dict_iterator():
|
||||||
|
@ -72,7 +72,7 @@ def test_simple_ngram():
|
||||||
def test_corner_cases():
|
def test_corner_cases():
|
||||||
""" testing various corner cases and exceptions"""
|
""" testing various corner cases and exceptions"""
|
||||||
|
|
||||||
def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None):
|
def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
|
||||||
def gen(texts):
|
def gen(texts):
|
||||||
yield (np.array(texts.split(" "), dtype='S'),)
|
yield (np.array(texts.split(" "), dtype='S'),)
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ def test_corner_cases():
|
||||||
try:
|
try:
|
||||||
test_config("Yours to Discover", "", [0, [1]])
|
test_config("Yours to Discover", "", [0, [1]])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert "ngram needs to be a positive number" in str(e)
|
assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
|
||||||
# test empty n
|
# test empty n
|
||||||
try:
|
try:
|
||||||
test_config("Yours to Discover", "", [])
|
test_config("Yours to Discover", "", [])
|
||||||
|
|
|
@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py():
|
||||||
_ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
|
_ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input is not within the required range" in str(e)
|
assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_grayscale_md5_01():
|
def test_normalize_grayscale_md5_01():
|
||||||
|
|
|
@ -61,6 +61,10 @@ def test_pad_end_exceptions():
|
||||||
pad_compare([3, 4, 5], ["2"], 1, [])
|
pad_compare([3, 4, 5], ["2"], 1, [])
|
||||||
assert "a value in the list is not an integer." in str(info.value)
|
assert "a value in the list is not an integer." in str(info.value)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as info:
|
||||||
|
pad_compare([1, 2], 3, -1, [1, 2, -1])
|
||||||
|
assert "Argument pad_end with value 3 is not of type (<class 'list'>,)" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_pad_end_basics()
|
test_pad_end_basics()
|
||||||
|
|
|
@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees():
|
||||||
_ = py_vision.RandomAffine(degrees=-15)
|
_ = py_vision.RandomAffine(degrees=-15)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "If degrees is a single number, it cannot be negative."
|
assert str(e) == "Input degrees is not within the required interval of (0 to inf)."
|
||||||
|
|
||||||
|
|
||||||
def test_random_affine_exception_translation_range():
|
def test_random_affine_exception_translation_range():
|
||||||
|
@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range():
|
||||||
_ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5))
|
_ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "translation values should be between 0 and 1"
|
assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)."
|
||||||
|
|
||||||
|
|
||||||
def test_random_affine_exception_scale_value():
|
def test_random_affine_exception_scale_value():
|
||||||
|
@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value():
|
||||||
_ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1))
|
_ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "scale values should be positive"
|
assert str(e) == "Input scale[0] must be greater than 0."
|
||||||
|
|
||||||
|
|
||||||
def test_random_affine_exception_shear_value():
|
def test_random_affine_exception_shear_value():
|
||||||
|
@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value():
|
||||||
_ = py_vision.RandomAffine(degrees=15, shear=-5)
|
_ = py_vision.RandomAffine(degrees=15, shear=-5)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "If shear is a single number, it must be positive."
|
assert str(e) == "Input shear must be greater than 0."
|
||||||
|
|
||||||
|
|
||||||
def test_random_affine_exception_degrees_size():
|
def test_random_affine_exception_degrees_size():
|
||||||
|
@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size():
|
||||||
_ = py_vision.RandomAffine(degrees=15, translate=(0.1))
|
_ = py_vision.RandomAffine(degrees=15, translate=(0.1))
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "translate should be a list or tuple of length 2."
|
assert str(
|
||||||
|
e) == "Argument translate with value 0.1 is not of type (<class 'list'>," \
|
||||||
|
" <class 'tuple'>)."
|
||||||
|
|
||||||
|
|
||||||
def test_random_affine_exception_scale_size():
|
def test_random_affine_exception_scale_size():
|
||||||
|
@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size():
|
||||||
_ = py_vision.RandomAffine(degrees=15, scale=(0.5))
|
_ = py_vision.RandomAffine(degrees=15, scale=(0.5))
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "scale should be a list or tuple of length 2."
|
assert str(e) == "Argument scale with value 0.5 is not of type (<class 'tuple'>," \
|
||||||
|
" <class 'list'>)."
|
||||||
|
|
||||||
|
|
||||||
def test_random_affine_exception_shear_size():
|
def test_random_affine_exception_shear_size():
|
||||||
|
@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size():
|
||||||
_ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10))
|
_ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10))
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4."
|
assert str(e) == "shear must be of length 2 or 4."
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -97,7 +97,7 @@ def test_random_color_md5():
|
||||||
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||||
|
|
||||||
transforms = F.ComposeOp([F.Decode(),
|
transforms = F.ComposeOp([F.Decode(),
|
||||||
F.RandomColor((0.5, 1.5)),
|
F.RandomColor((0.1, 1.9)),
|
||||||
F.ToTensor()])
|
F.ToTensor()])
|
||||||
|
|
||||||
data = data.map(input_columns="image", operations=transforms())
|
data = data.map(input_columns="image", operations=transforms())
|
||||||
|
|
|
@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c():
|
||||||
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input range is not valid" in str(e)
|
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_crop_and_resize_04_py():
|
def test_random_crop_and_resize_04_py():
|
||||||
|
@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py():
|
||||||
data = data.map(input_columns=["image"], operations=transform())
|
data = data.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input range is not valid" in str(e)
|
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_crop_and_resize_05_c():
|
def test_random_crop_and_resize_05_c():
|
||||||
|
@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c():
|
||||||
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input range is not valid" in str(e)
|
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_crop_and_resize_05_py():
|
def test_random_crop_and_resize_05_py():
|
||||||
|
@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py():
|
||||||
data = data.map(input_columns=["image"], operations=transform())
|
data = data.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input range is not valid" in str(e)
|
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_crop_and_resize_comp(plot=False):
|
def test_random_crop_and_resize_comp(plot=False):
|
||||||
|
|
|
@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "Input range is not valid" in str(err)
|
assert "Input is not within the required interval of (0 to 16777216)." in str(err)
|
||||||
|
|
||||||
|
|
||||||
def test_random_resized_crop_with_bbox_op_invalid2_c():
|
def test_random_resized_crop_with_bbox_op_invalid2_c():
|
||||||
|
@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "Input range is not valid" in str(err)
|
assert "Input is not within the required interval of (0 to 16777216)." in str(err)
|
||||||
|
|
||||||
|
|
||||||
def test_random_resized_crop_with_bbox_op_bad_c():
|
def test_random_resized_crop_with_bbox_op_bad_c():
|
||||||
|
|
|
@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param():
|
||||||
data = data.map(input_columns=["image"], operations=transform())
|
data = data.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input is not within the required range" in str(e)
|
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_random_grayscale_valid_prob(True)
|
test_random_grayscale_valid_prob(True)
|
||||||
|
|
|
@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c():
|
||||||
data = data.map(input_columns=["image"], operations=random_horizontal_op)
|
data = data.map(input_columns=["image"], operations=random_horizontal_op)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input is not" in str(e)
|
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_horizontal_invalid_prob_py():
|
def test_random_horizontal_invalid_prob_py():
|
||||||
|
@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py():
|
||||||
data = data.map(input_columns=["image"], operations=transform())
|
data = data.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input is not" in str(e)
|
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_horizontal_comp(plot=False):
|
def test_random_horizontal_comp(plot=False):
|
||||||
|
|
|
@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
|
||||||
operations=[test_op]) # Add column for "annotation"
|
operations=[test_op]) # Add column for "annotation"
|
||||||
except ValueError as error:
|
except ValueError as error:
|
||||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||||
assert "Input is not" in str(error)
|
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error)
|
||||||
|
|
||||||
|
|
||||||
def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
|
def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
|
||||||
|
|
|
@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range():
|
||||||
_ = py_vision.RandomPerspective(distortion_scale=1.5)
|
_ = py_vision.RandomPerspective(distortion_scale=1.5)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "Input is not within the required range"
|
assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)."
|
||||||
|
|
||||||
|
|
||||||
def test_random_perspective_exception_prob_range():
|
def test_random_perspective_exception_prob_range():
|
||||||
|
@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range():
|
||||||
_ = py_vision.RandomPerspective(prob=1.2)
|
_ = py_vision.RandomPerspective(prob=1.2)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert str(e) == "Input is not within the required range"
|
assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)."
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c():
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "Input is not" in str(err)
|
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# one of the size values is zero
|
# one of the size values is zero
|
||||||
|
@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c():
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "Input is not" in str(err)
|
assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# negative value for resize
|
# negative value for resize
|
||||||
|
@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c():
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "Input is not" in str(err)
|
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# invalid input shape
|
# invalid input shape
|
||||||
|
|
|
@ -97,7 +97,7 @@ def test_random_sharpness_md5():
|
||||||
# define map operations
|
# define map operations
|
||||||
transforms = [
|
transforms = [
|
||||||
F.Decode(),
|
F.Decode(),
|
||||||
F.RandomSharpness((0.5, 1.5)),
|
F.RandomSharpness((0.1, 1.9)),
|
||||||
F.ToTensor()
|
F.ToTensor()
|
||||||
]
|
]
|
||||||
transform = F.ComposeOp(transforms)
|
transform = F.ComposeOp(transforms)
|
||||||
|
|
|
@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c():
|
||||||
data = data.map(input_columns=["image"], operations=random_horizontal_op)
|
data = data.map(input_columns=["image"], operations=random_horizontal_op)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input is not" in str(e)
|
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_vertical_invalid_prob_py():
|
def test_random_vertical_invalid_prob_py():
|
||||||
|
@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py():
|
||||||
data = data.map(input_columns=["image"], operations=transform())
|
data = data.map(input_columns=["image"], operations=transform())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "Input is not" in str(e)
|
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_random_vertical_comp(plot=False):
|
def test_random_vertical_comp(plot=False):
|
||||||
|
|
|
@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
|
||||||
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "Input is not" in str(err)
|
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err)
|
||||||
|
|
||||||
|
|
||||||
def test_random_vertical_flip_with_bbox_op_bad_c():
|
def test_random_vertical_flip_with_bbox_op_bad_c():
|
||||||
|
|
|
@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c():
|
||||||
# invalid interpolation value
|
# invalid interpolation value
|
||||||
c_vision.ResizeWithBBox(400, interpolation="invalid")
|
c_vision.ResizeWithBBox(400, interpolation="invalid")
|
||||||
|
|
||||||
except ValueError as err:
|
except TypeError as err:
|
||||||
logger.info("Got an exception in DE: {}".format(str(err)))
|
logger.info("Got an exception in DE: {}".format(str(err)))
|
||||||
assert "interpolation" in str(err)
|
assert "interpolation" in str(err)
|
||||||
|
|
||||||
|
|
|
@ -154,7 +154,7 @@ def test_shuffle_exception_01():
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "buffer_size" in str(e)
|
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_exception_02():
|
def test_shuffle_exception_02():
|
||||||
|
@ -172,7 +172,7 @@ def test_shuffle_exception_02():
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "buffer_size" in str(e)
|
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_exception_03():
|
def test_shuffle_exception_03():
|
||||||
|
@ -190,7 +190,7 @@ def test_shuffle_exception_03():
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "buffer_size" in str(e)
|
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_exception_05():
|
def test_shuffle_exception_05():
|
||||||
|
|
|
@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
|
||||||
logger.info("dtype of image_2: {}".format(image_2.dtype))
|
logger.info("dtype of image_2: {}".format(image_2.dtype))
|
||||||
|
|
||||||
if plot:
|
if plot:
|
||||||
visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
|
visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
|
||||||
|
|
||||||
# The output data should be of a 4D tensor shape, a stack of 10 images.
|
# The output data should be of a 4D tensor shape, a stack of 10 images.
|
||||||
assert len(image_2.shape) == 4
|
assert len(image_2.shape) == 4
|
||||||
|
@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg():
|
||||||
vision.TenCrop(0),
|
vision.TenCrop(0),
|
||||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||||
]
|
]
|
||||||
error_msg = "Input is not within the required range"
|
error_msg = "Input is not within the required interval of (1 to 16777216)."
|
||||||
assert error_msg == str(info.value)
|
assert error_msg == str(info.value)
|
||||||
|
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(ValueError) as info:
|
||||||
|
|
|
@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "operations" in str(e)
|
assert "Argument tensor_op_5 with value" \
|
||||||
|
" <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e)
|
||||||
|
assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
|
||||||
|
@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "num_ops" in str(e)
|
assert "Input num_ops must be greater than 0" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||||
|
@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
assert "integer" in str(e)
|
assert "Argument num_ops with value 2.5 is not of type (<class 'int'>,)" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
|
||||||
|
|
|
@ -314,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
|
||||||
if len(orig) != len(aug) or not orig:
|
if len(orig) != len(aug) or not orig:
|
||||||
return
|
return
|
||||||
|
|
||||||
batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together
|
batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together
|
||||||
split_point = batch_size * plot_rows
|
split_point = batch_size * plot_rows
|
||||||
|
|
||||||
orig, aug = np.array(orig), np.array(aug)
|
orig, aug = np.array(orig), np.array(aug)
|
||||||
|
|
||||||
if len(orig) > plot_rows:
|
if len(orig) > plot_rows:
|
||||||
# Create batches of required size and add remainder to last batch
|
# Create batches of required size and add remainder to last batch
|
||||||
orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
|
orig = np.split(orig[:split_point], batch_size) + (
|
||||||
|
[orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
|
||||||
aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
|
aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
|
||||||
else:
|
else:
|
||||||
orig = [orig]
|
orig = [orig]
|
||||||
|
@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
|
||||||
|
|
||||||
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
|
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
|
||||||
cur_ix = base_ix + x
|
cur_ix = base_ix + x
|
||||||
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row
|
# select plotting axes based on number of image rows on plot - else case when 1 row
|
||||||
|
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1])
|
||||||
|
|
||||||
axA.imshow(dataA["image"])
|
axA.imshow(dataA["image"])
|
||||||
add_bounding_boxes(axA, dataA[annot_name])
|
add_bounding_boxes(axA, dataA[annot_name])
|
||||||
|
|
Loading…
Reference in New Issue