dataset API docstring cleanup: Standard product terms NumPy, Python

This commit is contained in:
Cathy Wong 2020-08-31 20:34:25 -04:00
parent aef726534e
commit 4d4c11b133
23 changed files with 284 additions and 279 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""init file for python callback""" """init file for Python callback"""
from .ds_callback import DSCallback, WaitedDSCallback from .ds_callback import DSCallback, WaitedDSCallback
__all__ = ["DSCallback", "WaitedDSCallback"] __all__ = ["DSCallback", "WaitedDSCallback"]

View File

@ -33,8 +33,8 @@ def set_seed(seed):
Set the seed to be used in any random generator. This is used to produce deterministic results. Set the seed to be used in any random generator. This is used to produce deterministic results.
Note: Note:
This set_seed function sets the seed in the python random library and numpy.random library This set_seed function sets the seed in the Python random library and numpy.random library
for deterministic python augmentations using randomness. This set_seed function should for deterministic Python augmentations using randomness. This set_seed function should
be called with every iterator created to reset the random seed. In our pipeline this be called with every iterator created to reset the random seed. In our pipeline this
does not guarantee deterministic results with num_parallel_workers > 1. does not guarantee deterministic results with num_parallel_workers > 1.

View File

@ -369,6 +369,6 @@ def check_gnn_list_or_ndarray(param, param_name):
def check_tensor_op(param, param_name): def check_tensor_op(param, param_name):
"""check whether param is a tensor op or a callable python function""" """check whether param is a tensor op or a callable Python function"""
if not isinstance(param, cde.TensorOp) and not callable(param): if not isinstance(param, cde.TensorOp) and not callable(param):
raise TypeError("{0} is not a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name)) raise TypeError("{0} is not a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))

View File

@ -434,8 +434,8 @@ class Dataset:
same). same).
num_parallel_workers (int, optional): Number of threads used to process the dataset in num_parallel_workers (int, optional): Number of threads used to process the dataset in
parallel (default=None, the value from the config will be used). parallel (default=None, the value from the config will be used).
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=False). option could be beneficial if the Python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended. The cache feature is under development and is not recommended.
callbacks: (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None). callbacks: (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None).
@ -565,7 +565,7 @@ class Dataset:
If input_columns not provided or empty, all columns will be used. If input_columns not provided or empty, all columns will be used.
Args: Args:
predicate(callable): python callable which returns a boolean value, if False then filter the element. predicate(callable): Python callable which returns a boolean value, if False then filter the element.
input_columns: (list[str], optional): List of names of the input columns, when input_columns: (list[str], optional): List of names of the input columns, when
default=None, the predicate will be applied on all columns in the dataset. default=None, the predicate will be applied on all columns in the dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset num_parallel_workers (int, optional): Number of workers to process the Dataset
@ -1541,7 +1541,7 @@ class MappableDataset(SourceDataset):
class DatasetOp(Dataset): class DatasetOp(Dataset):
""" """
Abstract class to represent a operations on dataset. Abstract class to represent an operation on a dataset.
""" """
# No need for __init__ since it is the same as the super's init # No need for __init__ since it is the same as the super's init
@ -1907,7 +1907,7 @@ _GLOBAL_PYFUNC_LIST = []
# Pyfunc worker init function # Pyfunc worker init function
# Python multiprocessing library forbid sending lambda function through pipe. # Python multiprocessing library forbid sending lambda function through pipe.
# This init function allow us to add all python function to a global collection and then fork afterwards. # This init function allow us to add all Python function to a global collection and then fork afterwards.
def _pyfunc_worker_init(pyfunc_list): def _pyfunc_worker_init(pyfunc_list):
global _GLOBAL_PYFUNC_LIST global _GLOBAL_PYFUNC_LIST
_GLOBAL_PYFUNC_LIST = pyfunc_list _GLOBAL_PYFUNC_LIST = pyfunc_list
@ -1925,11 +1925,11 @@ def _pyfunc_worker_exec(index, *args):
# PythonCallable wrapper for multiprocess pyfunc # PythonCallable wrapper for multiprocess pyfunc
class _PythonCallable: class _PythonCallable:
""" """
Internal python function wrapper for multiprocessing pyfunc. Internal Python function wrapper for multiprocessing pyfunc.
""" """
def __init__(self, py_callable, idx, pool=None): def __init__(self, py_callable, idx, pool=None):
# Original python callable from user. # Original Python callable from user.
self.py_callable = py_callable self.py_callable = py_callable
# Process pool created for current iterator. # Process pool created for current iterator.
self.pool = pool self.pool = pool
@ -1946,7 +1946,7 @@ class _PythonCallable:
self.pool.terminate() self.pool.terminate()
self.pool.join() self.pool.join()
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
# Invoke original python callable in master process in case the pool is gone. # Invoke original Python callable in master process in case the pool is gone.
return self.py_callable(*args) return self.py_callable(*args)
@ -1969,8 +1969,8 @@ class MapDataset(DatasetOp):
The argument is mandatory if len(input_columns) != len(output_columns). The argument is mandatory if len(input_columns) != len(output_columns).
num_parallel_workers (int, optional): Number of workers to process the Dataset num_parallel_workers (int, optional): Number of workers to process the Dataset
in parallel (default=None). in parallel (default=None).
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=False). option could be beneficial if the Python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended. The cache feature is under development and is not recommended.
callbacks: (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None) callbacks: (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None)
@ -2065,7 +2065,7 @@ class MapDataset(DatasetOp):
iter_specific_operations = [] iter_specific_operations = []
callable_list = [] callable_list = []
# Pass #1, look for python callables and build list # Pass #1, look for Python callables and build list
for op in self.operations: for op in self.operations:
if callable(op): if callable(op):
callable_list.append(op) callable_list.append(op)
@ -2080,7 +2080,7 @@ class MapDataset(DatasetOp):
idx = 0 idx = 0
for op in self.operations: for op in self.operations:
if callable(op): if callable(op):
# Wrap python callable into _PythonCallable # Wrap Python callable into _PythonCallable
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
idx += 1 idx += 1
else: else:
@ -2099,7 +2099,7 @@ class FilterDataset(DatasetOp):
Args: Args:
input_dataset: Input Dataset to be mapped. input_dataset: Input Dataset to be mapped.
predicate: python callable which returns a boolean value, if False then filter the element. predicate: Python callable which returns a boolean value, if False then filter the element.
input_columns: (list[str]): List of names of the input columns, when input_columns: (list[str]): List of names of the input columns, when
default=None, the predicate will be applied all columns in the dataset. default=None, the predicate will be applied all columns in the dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset num_parallel_workers (int, optional): Number of workers to process the Dataset
@ -3079,7 +3079,7 @@ def _generator_fn(generator, num_samples):
def _py_sampler_fn(sampler, num_samples, dataset): def _py_sampler_fn(sampler, num_samples, dataset):
""" """
Generator function wrapper for mappable dataset with python sampler. Generator function wrapper for mappable dataset with Python sampler.
""" """
if num_samples is not None: if num_samples is not None:
sampler_iter = iter(sampler) sampler_iter = iter(sampler)
@ -3120,7 +3120,7 @@ def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process): def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
""" """
Multiprocessing generator function wrapper for mappable dataset with python sampler. Multiprocessing generator function wrapper for mappable dataset with Python sampler.
""" """
indices = _fetch_py_sampler_indices(sampler, num_samples) indices = _fetch_py_sampler_indices(sampler, num_samples)
sample_fn = SamplerFn(dataset, num_worker, multi_process) sample_fn = SamplerFn(dataset, num_worker, multi_process)
@ -3129,7 +3129,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
def _fetch_py_sampler_indices(sampler, num_samples): def _fetch_py_sampler_indices(sampler, num_samples):
""" """
Indice fetcher for python sampler. Indice fetcher for Python sampler.
""" """
if num_samples is not None: if num_samples is not None:
sampler_iter = iter(sampler) sampler_iter = iter(sampler)
@ -3316,7 +3316,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
class GeneratorDataset(MappableDataset): class GeneratorDataset(MappableDataset):
""" """
A source dataset that generates data from python by invoking python data source each epoch. A source dataset that generates data from Python by invoking Python data source each epoch.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior. below shows what input args are allowed and their expected behavior.
@ -3349,10 +3349,11 @@ class GeneratorDataset(MappableDataset):
Args: Args:
source (Union[Callable, Iterable, Random Accessible]): source (Union[Callable, Iterable, Random Accessible]):
A generator callable object, an iterable python object or a random accessible python object. A generator callable object, an iterable Python object or a random accessible Python object.
Callable source is required to return a tuple of numpy array as a row of the dataset on source().next(). Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next().
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next(). Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on
Random accessible source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on
source[idx]. source[idx].
column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
provide either column_names or schema. provide either column_names or schema.
@ -3371,8 +3372,8 @@ class GeneratorDataset(MappableDataset):
When this argument is specified, 'num_samples' will not effect. Random accessible input is required. When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required. when num_shards is also specified. Random accessible input is required.
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=True). option could be beneficial if the Python operation is computational heavy (default=True).
Examples: Examples:
>>> import mindspore.dataset as ds >>> import mindspore.dataset as ds
@ -4474,7 +4475,7 @@ class VOCDataset(MappableDataset):
argument should be specified only when num_shards is also specified. argument should be specified only when num_shards is also specified.
Raises: Raises:
RuntimeError: If xml of Annotations is a invalid format. RuntimeError: If xml of Annotations is an invalid format.
RuntimeError: If xml of Annotations loss attribution of "object". RuntimeError: If xml of Annotations loss attribution of "object".
RuntimeError: If xml of Annotations loss attribution of "bndbox". RuntimeError: If xml of Annotations loss attribution of "bndbox".
RuntimeError: If sampler and shuffle are specified at the same time. RuntimeError: If sampler and shuffle are specified at the same time.
@ -5322,7 +5323,7 @@ class TextFileDataset(SourceDataset):
class _NumpySlicesDataset: class _NumpySlicesDataset:
""" """
Mainly for dealing with several kinds of format of python data, and return one row each time. Mainly for dealing with several kinds of format of Python data, and return one row each time.
""" """
def __init__(self, data, column_list=None): def __init__(self, data, column_list=None):
@ -5388,7 +5389,7 @@ class _NumpySlicesDataset:
class NumpySlicesDataset(GeneratorDataset): class NumpySlicesDataset(GeneratorDataset):
""" """
Create a dataset with given data slices, mainly for loading python data into dataset. Create a dataset with given data slices, mainly for loading Python data into dataset.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior. below shows what input args are allowed and their expected behavior.
@ -5421,7 +5422,7 @@ class NumpySlicesDataset(GeneratorDataset):
Args: Args:
data (Union[list, tuple, dict]) Input of Given data, supported data type includes list, tuple, dict and other data (Union[list, tuple, dict]) Input of Given data, supported data type includes list, tuple, dict and other
numpy format. Input data will be sliced in first dimension and generate many rows, large data is not NumPy format. Input data will be sliced in first dimension and generate many rows, large data is not
recommend to load in this way as data is loading into memory. recommend to load in this way as data is loading into memory.
column_names (list[str], optional): List of column names of the dataset (default=None). If column_names not column_names (list[str], optional): List of column names of the dataset (default=None). If column_names not
provided, when data is dict, column_names will be its key, otherwise it will be like column_1, column_2 ... provided, when data is dict, column_names will be its key, otherwise it will be like column_1, column_2 ...
@ -5444,7 +5445,7 @@ class NumpySlicesDataset(GeneratorDataset):
>>> # 2) Input data can be a dict, and column_names will be its key >>> # 2) Input data can be a dict, and column_names will be its key
>>> data = {"a": [1, 2], "b": [3, 4]} >>> data = {"a": [1, 2], "b": [3, 4]}
>>> dataset2 = ds.NumpySlicesDataset(data) >>> dataset2 = ds.NumpySlicesDataset(data)
>>> # 3) Input data can be a tuple of lists (or numpy arrays), each tuple element refers to data in each column >>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column
>>> data = ([1, 2], [3, 4], [5, 6]) >>> data = ([1, 2], [3, 4], [5, 6])
>>> dataset3 = ds.NumpySlicesDataset(data, column_names=["column_1", "column_2", "column_3"]) >>> dataset3 = ds.NumpySlicesDataset(data, column_names=["column_1", "column_2", "column_3"])
>>> # 4) Load data from csv file >>> # 4) Load data from csv file

View File

@ -38,7 +38,7 @@ def _cleanup():
def alter_tree(node): def alter_tree(node):
"""Traversing the python Dataset tree/graph to perform some alteration to some specific nodes.""" """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
if not node.children: if not node.children:
return _alter_node(node) return _alter_node(node)
@ -98,9 +98,9 @@ class Iterator:
def stop(self): def stop(self):
""" """
Manually terminate python iterator instead of relying on out of scope destruction. Manually terminate Python iterator instead of relying on out of scope destruction.
""" """
logger.info("terminating python iterator. This will also terminate c++ pipeline.") logger.info("terminating Python iterator. This will also terminate c++ pipeline.")
if hasattr(self, 'depipeline') and self.depipeline: if hasattr(self, 'depipeline') and self.depipeline:
del self.depipeline del self.depipeline
@ -193,7 +193,7 @@ class Iterator:
return op_type return op_type
# Convert python node into C node and add to C layer execution tree in postorder traversal. # Convert Python node into C node and add to C layer execution tree in postorder traversal.
def __convert_node_postorder(self, node): def __convert_node_postorder(self, node):
self.check_node_type(node) self.check_node_type(node)
op_type = self.__get_dataset_type(node) op_type = self.__get_dataset_type(node)

View File

@ -48,7 +48,7 @@ def serialize(dataset, json_filepath=None):
>>> data = data.batch(batch_size=10, drop_remainder=True) >>> data = data.batch(batch_size=10, drop_remainder=True)
>>> >>>
>>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json") # serialize it to json file >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json") # serialize it to json file
>>> serialized_data = ds.engine.serialize(data) # serialize it to python dict >>> serialized_data = ds.engine.serialize(data) # serialize it to Python dict
""" """
serialized_pipeline = traverse(dataset) serialized_pipeline = traverse(dataset)
if json_filepath: if json_filepath:
@ -62,7 +62,7 @@ def deserialize(input_dict=None, json_filepath=None):
Construct a de pipeline from a json file produced by de.serialize(). Construct a de pipeline from a json file produced by de.serialize().
Args: Args:
input_dict (dict): a python dictionary containing a serialized dataset graph input_dict (dict): a Python dictionary containing a serialized dataset graph
json_filepath (str): a path to the json file. json_filepath (str): a path to the json file.
Returns: Returns:
@ -83,7 +83,7 @@ def deserialize(input_dict=None, json_filepath=None):
>>> # Use case 1: to/from json file >>> # Use case 1: to/from json file
>>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json") >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")
>>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json") >>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json")
>>> # Use case 2: to/from python dictionary >>> # Use case 2: to/from Python dictionary
>>> serialized_data = ds.engine.serialize(data) >>> serialized_data = ds.engine.serialize(data)
>>> data = ds.engine.deserialize(input_dict=serialized_data) >>> data = ds.engine.deserialize(input_dict=serialized_data)
@ -110,12 +110,12 @@ def expand_path(node_repr, key, val):
def serialize_operations(node_repr, key, val): def serialize_operations(node_repr, key, val):
"""Serialize tensor op (python object) to dictionary.""" """Serialize tensor op (Python object) to dictionary."""
if isinstance(val, list): if isinstance(val, list):
node_repr[key] = [] node_repr[key] = []
for op in val: for op in val:
node_repr[key].append(op.__dict__) node_repr[key].append(op.__dict__)
# Extracting module and name information from a python object # Extracting module and name information from a Python object
# Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode' # Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode'
node_repr[key][-1]['tensor_op_name'] = type(op).__name__ node_repr[key][-1]['tensor_op_name'] = type(op).__name__
node_repr[key][-1]['tensor_op_module'] = type(op).__module__ node_repr[key][-1]['tensor_op_module'] = type(op).__module__
@ -137,7 +137,7 @@ def serialize_sampler(node_repr, val):
def traverse(node): def traverse(node):
"""Pre-order traverse the pipeline and capture the information as we go.""" """Pre-order traverse the pipeline and capture the information as we go."""
# Node representation (node_repr) is a python dictionary that capture and store the # Node representation (node_repr) is a Python dictionary that capture and store the
# dataset pipeline information before dumping it to JSON or other format. # dataset pipeline information before dumping it to JSON or other format.
node_repr = dict() node_repr = dict()
node_repr['op_type'] = type(node).__name__ node_repr['op_type'] = type(node).__name__
@ -222,12 +222,12 @@ def compare(pipeline1, pipeline2):
def construct_pipeline(node): def construct_pipeline(node):
"""Construct the python Dataset objects by following the dictionary deserialized from json file.""" """Construct the Python Dataset objects by following the dictionary deserialized from json file."""
op_type = node.get('op_type') op_type = node.get('op_type')
if not op_type: if not op_type:
raise ValueError("op_type field in the json file can't be None.") raise ValueError("op_type field in the json file can't be None.")
# Instantiate python Dataset object based on the current dictionary element # Instantiate Python Dataset object based on the current dictionary element
dataset = create_node(node) dataset = create_node(node)
# Initially it is not connected to any other object. # Initially it is not connected to any other object.
dataset.children = [] dataset.children = []
@ -240,12 +240,12 @@ def construct_pipeline(node):
def create_node(node): def create_node(node):
"""Parse the key, value in the node dictionary and instantiate the python Dataset object""" """Parse the key, value in the node dictionary and instantiate the Python Dataset object"""
logger.info('creating node: %s', node['op_type']) logger.info('creating node: %s', node['op_type'])
dataset_op = node['op_type'] dataset_op = node['op_type']
op_module = node['op_module'] op_module = node['op_module']
# Get the python class to be instantiated. # Get the Python class to be instantiated.
# Example: # Example:
# "op_type": "MapDataset", # "op_type": "MapDataset",
# "op_module": "mindspore.dataset.datasets", # "op_module": "mindspore.dataset.datasets",

View File

@ -589,7 +589,7 @@ def check_filter(method):
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
[predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
if not callable(predicate): if not callable(predicate):
raise TypeError("Predicate should be a python function or a callable python object.") raise TypeError("Predicate should be a Python function or a callable Python object.")
check_num_parallel_workers(num_parallel_workers) check_num_parallel_workers(num_parallel_workers)

View File

@ -484,7 +484,7 @@ if platform.system().lower() != 'windows':
The original string will be split by matched elements. The original string will be split by matched elements.
keep_delim_pattern(str, optional): The string matched by 'delim_pattern' can be kept as a token keep_delim_pattern(str, optional): The string matched by 'delim_pattern' can be kept as a token
if it can be matched by 'keep_delim_pattern'. And the default value is empty str(''), if it can be matched by 'keep_delim_pattern'. And the default value is empty str(''),
in this situation, delimiters will not kept as a output token(default=''). in this situation, delimiters will not kept as an output token(default='').
with_offsets (bool, optional): If or not output offsets of tokens (default=False). with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples: Examples:

View File

@ -213,36 +213,36 @@ class SentencePieceVocab(cde.SentencePieceVocab):
def to_str(array, encoding='utf8'): def to_str(array, encoding='utf8'):
""" """
Convert numpy array of `bytes` to array of `str` by decoding each element based on charset `encoding`. Convert NumPy array of `bytes` to array of `str` by decoding each element based on charset `encoding`.
Args: Args:
array (numpy.ndarray): Array of type `bytes` representing strings. array (numpy.ndarray): Array of type `bytes` representing strings.
encoding (str): Indicating the charset for decoding. encoding (str): Indicating the charset for decoding.
Returns: Returns:
numpy.ndarray, numpy array of `str`. numpy.ndarray, NumPy array of `str`.
""" """
if not isinstance(array, np.ndarray): if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array.') raise ValueError('input should be a NumPy array.')
return np.char.decode(array, encoding) return np.char.decode(array, encoding)
def to_bytes(array, encoding='utf8'): def to_bytes(array, encoding='utf8'):
""" """
Convert numpy array of `str` to array of `bytes` by encoding each element based on charset `encoding`. Convert NumPy array of `str` to array of `bytes` by encoding each element based on charset `encoding`.
Args: Args:
array (numpy.ndarray): Array of type `str` representing strings. array (numpy.ndarray): Array of type `str` representing strings.
encoding (str): Indicating the charset for encoding. encoding (str): Indicating the charset for encoding.
Returns: Returns:
numpy.ndarray, numpy array of `bytes`. numpy.ndarray, NumPy array of `bytes`.
""" """
if not isinstance(array, np.ndarray): if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array.') raise ValueError('input should be a NumPy array.')
return np.char.encode(array, encoding) return np.char.encode(array, encoding)

View File

@ -414,7 +414,7 @@ def check_python_tokenizer(method):
[tokenizer], _ = parse_user_args(method, *args, **kwargs) [tokenizer], _ = parse_user_args(method, *args, **kwargs)
if not callable(tokenizer): if not callable(tokenizer):
raise TypeError("tokenizer is not a callable python function") raise TypeError("tokenizer is not a callable Python function")
return method(self, *args, **kwargs) return method(self, *args, **kwargs)

View File

@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
""" """
This module is to support common augmentations. C_transforms is a high performance This module is to support common augmentations. C_transforms is a high performance
image augmentation module which is developed with c++ opencv. Py_transforms image augmentation module which is developed with C++ OpenCV. Py_transforms
provide more kinds of image augmentations which is developed with python PIL. provide more kinds of image augmentations which is developed with Python PIL.
""" """
from . import vision from . import vision
from . import c_transforms from . import c_transforms

View File

@ -89,8 +89,8 @@ class Slice(cde.SliceOp):
1. :py:obj:`int`: Slice this index only. Negative index is supported. 1. :py:obj:`int`: Slice this index only. Negative index is supported.
2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supported. 2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supported.
3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`. 3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`.
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in python indexing. 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing.
5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in python indexing. 5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in Python indexing.
Examples: Examples:
>>> # Data before >>> # Data before
@ -206,8 +206,8 @@ class Concatenate(cde.ConcatenateOp):
Args: Args:
axis (int, optional): concatenate the tensors along given axis (Default=0). axis (int, optional): concatenate the tensors along given axis (Default=0).
prepend (numpy.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None). prepend (numpy.array, optional): NumPy array to be prepended to the already concatenated tensors (Default=None).
append (numpy.array, optional): numpy array to be appended to the already concatenated tensors (Default=None). append (numpy.array, optional): NumPy array to be appended to the already concatenated tensors (Default=None).
""" """
@check_concat_type @check_concat_type

View File

@ -14,7 +14,7 @@
# ============================================================================== # ==============================================================================
""" """
This module py_transforms is implemented basing on python. It provides common This module py_transforms is implemented basing on Python. It provides common
operations including OneHotOp. operations including OneHotOp.
""" """

View File

@ -15,7 +15,7 @@
This module is to support vision augmentations. It includes two parts: This module is to support vision augmentations. It includes two parts:
c_transforms and py_transforms. C_transforms is a high performance c_transforms and py_transforms. C_transforms is a high performance
image augmentation module which is developed with c++ opencv. Py_transforms image augmentation module which is developed with c++ opencv. Py_transforms
provide more kinds of image augmentations which is developed with python PIL. provide more kinds of image augmentations which is developed with Python PIL.
""" """
from . import c_transforms from . import c_transforms
from . import py_transforms from . import py_transforms

View File

@ -175,7 +175,7 @@ class CutMixBatch(cde.CutMixBatchOp):
class CutOut(cde.CutOutOp): class CutOut(cde.CutOutOp):
""" """
Randomly cut (mask) out a given number of square patches from the input Numpy image array. Randomly cut (mask) out a given number of square patches from the input NumPy image array.
Args: Args:
length (int): The side length of each square patch. length (int): The side length of each square patch.
@ -935,7 +935,7 @@ class UniformAugment(cde.UniformAugOp):
Tensor operation to perform randomly selected augmentation. Tensor operation to perform randomly selected augmentation.
Args: Args:
transforms: list of C++ operations (python OPs are not accepted). transforms: list of C++ operations (Python OPs are not accepted).
num_ops (int, optional): number of OPs to be selected and applied (default=2). num_ops (int, optional): number of OPs to be selected and applied (default=2).
Examples: Examples:

View File

@ -14,11 +14,11 @@
# ============================================================================== # ==============================================================================
""" """
The module vision.py_transforms is implemented basing on python The module vision.py_transforms is implemented based on Python PIL.
PIL. This module provides many kinds of image augmentations. It also provides This module provides many kinds of image augmentations. It also provides
transferring methods between PIL Image and numpy array. For users who prefer transferring methods between PIL image and NumPy array. For users who prefer
python PIL in image learning task, this module is a good tool to process image Python PIL in image learning task, this module is a good tool to process image
augmentations. Users could also self-define their own augmentations with python augmentations. Users can also self-define their own augmentations with Python
PIL. PIL.
""" """
import numbers import numbers
@ -85,22 +85,22 @@ class ComposeOp:
Call method. Call method.
Returns: Returns:
lambda function, Lambda function that takes in an img to apply transformations on. lambda function, Lambda function that takes in an image to apply transformations on.
""" """
return lambda img: util.compose(img, self.transforms) return lambda img: util.compose(img, self.transforms)
class ToTensor: class ToTensor:
""" """
Convert the input Numpy image array or PIL image of shape (H,W,C) to a Numpy ndarray of shape (C,H,W). Convert the input NumPy image array or PIL image of shape (H,W,C) to a NumPy ndarray of shape (C,H,W).
Note: Note:
The ranges of values in height and width dimension changes from [0, 255] to [0.0, 1.0]. Type cast to output_type The ranges of values in height and width dimension changes from [0, 255] to [0.0, 1.0]. Type cast to output_type
(default Numpy float 32). (default NumPy float 32).
The range of channel dimension remains the same. The range of channel dimension remains the same.
Args: Args:
output_type (numpy datatype, optional): The datatype of the numpy output (default=np.float32). output_type (Numpy datatype, optional): The datatype of the NumPy output (default=np.float32).
Examples: Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(), >>> py_transforms.ComposeOp([py_transforms.Decode(),
@ -116,7 +116,7 @@ class ToTensor:
Call method. Call method.
Args: Args:
img (PIL Image): PIL Image to be converted to numpy.ndarray. img (PIL image): PIL image to be converted to numpy.ndarray.
Returns: Returns:
img (numpy.ndarray), Converted image. img (numpy.ndarray), Converted image.
@ -126,10 +126,10 @@ class ToTensor:
class ToType: class ToType:
""" """
Convert the input Numpy image array to desired numpy dtype. Convert the input NumPy image array to desired NumPy dtype.
Args: Args:
output_type (numpy datatype): The datatype of the numpy output, e.g. numpy.float32. output_type (Numpy datatype): The datatype of the NumPy output, e.g. numpy.float32.
Examples: Examples:
>>> import numpy as np >>> import numpy as np
@ -147,7 +147,7 @@ class ToType:
Call method. Call method.
Args: Args:
numpy object : numpy object to be type swapped. NumPy object : NumPy object to be type swapped.
Returns: Returns:
img (numpy.ndarray), Converted image. img (numpy.ndarray), Converted image.
@ -157,7 +157,7 @@ class ToType:
class HWC2CHW: class HWC2CHW:
""" """
Transpose a Numpy image array; shape (H, W, C) to shape (C, H, W). Transpose a NumPy image array; shape (H, W, C) to shape (C, H, W).
""" """
def __call__(self, img): def __call__(self, img):
@ -175,10 +175,10 @@ class HWC2CHW:
class ToPIL: class ToPIL:
""" """
Convert the input decoded Numpy image array of RGB mode to a PIL Image of RGB mode. Convert the input decoded NumPy image array of RGB mode to a PIL image of RGB mode.
Examples: Examples:
>>> # data is already decoded, but not in PIL Image format >>> # data is already decoded, but not in PIL image format
>>> py_transforms.ComposeOp([py_transforms.ToPIL(), >>> py_transforms.ComposeOp([py_transforms.ToPIL(),
>>> py_transforms.RandomHorizontalFlip(0.5), >>> py_transforms.RandomHorizontalFlip(0.5),
>>> py_transforms.ToTensor()]) >>> py_transforms.ToTensor()])
@ -189,17 +189,17 @@ class ToPIL:
Call method. Call method.
Args: Args:
img (numpy.ndarray): Decoded image array, of RGB mode, to be converted to PIL Image. img (numpy.ndarray): Decoded image array, of RGB mode, to be converted to PIL image.
Returns: Returns:
img (PIL Image), Image converted to PIL Image of RGB mode. img (PIL image), Image converted to PIL image of RGB mode.
""" """
return util.to_pil(img) return util.to_pil(img)
class Decode: class Decode:
""" """
Decode the input image to PIL Image format in RGB mode. Decode the input image to PIL image format in RGB mode.
Examples: Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(), >>> py_transforms.ComposeOp([py_transforms.Decode(),
@ -215,14 +215,14 @@ class Decode:
img (Bytes-like Objects):Image to be decoded. img (Bytes-like Objects):Image to be decoded.
Returns: Returns:
img (PIL Image), Decoded image in RGB mode. img (PIL image), Decoded image in RGB mode.
""" """
return util.decode(img) return util.decode(img)
class Normalize: class Normalize:
""" """
Normalize the input Numpy image array of shape (C, H, W) with the given mean and standard deviation. Normalize the input NumPy image array of shape (C, H, W) with the given mean and standard deviation.
The values of the array need to be in range (0.0, 1.0]. The values of the array need to be in range (0.0, 1.0].
@ -257,7 +257,7 @@ class Normalize:
class RandomCrop: class RandomCrop:
""" """
Crop the input PIL Image at a random location. Crop the input PIL image at a random location.
Args: Args:
size (Union[int, sequence]): The output size of the cropped image. size (Union[int, sequence]): The output size of the cropped image.
@ -311,10 +311,10 @@ class RandomCrop:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be randomly cropped. img (PIL image): Image to be randomly cropped.
Returns: Returns:
PIL Image, Cropped image. PIL image, Cropped image.
""" """
return util.random_crop(img, self.size, self.padding, self.pad_if_needed, return util.random_crop(img, self.size, self.padding, self.pad_if_needed,
self.fill_value, self.padding_mode) self.fill_value, self.padding_mode)
@ -342,10 +342,10 @@ class RandomHorizontalFlip:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be flipped horizontally. img (PIL image): Image to be flipped horizontally.
Returns: Returns:
img (PIL Image), Randomly flipped image. img (PIL image), Randomly flipped image.
""" """
return util.random_horizontal_flip(img, self.prob) return util.random_horizontal_flip(img, self.prob)
@ -372,17 +372,17 @@ class RandomVerticalFlip:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be flipped vertically. img (PIL image): Image to be flipped vertically.
Returns: Returns:
img (PIL Image), Randomly flipped image. img (PIL image), Randomly flipped image.
""" """
return util.random_vertical_flip(img, self.prob) return util.random_vertical_flip(img, self.prob)
class Resize: class Resize:
""" """
Resize the input PIL Image to the given size. Resize the input PIL image to the given size.
Args: Args:
size (Union[int, sequence]): The output size of the resized image. size (Union[int, sequence]): The output size of the resized image.
@ -414,10 +414,10 @@ class Resize:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be resized. img (PIL image): Image to be resized.
Returns: Returns:
img (PIL Image), Resize image. img (PIL image), Resize image.
""" """
return util.resize(img, self.size, self.interpolation) return util.resize(img, self.size, self.interpolation)
@ -465,10 +465,10 @@ class RandomResizedCrop:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be randomly cropped and resized. img (PIL image): Image to be randomly cropped and resized.
Returns: Returns:
img (PIL Image), Randomly cropped and resized image. img (PIL image), Randomly cropped and resized image.
""" """
return util.random_resize_crop(img, self.size, self.scale, self.ratio, return util.random_resize_crop(img, self.size, self.scale, self.ratio,
self.interpolation, self.max_attempts) self.interpolation, self.max_attempts)
@ -476,7 +476,7 @@ class RandomResizedCrop:
class CenterCrop: class CenterCrop:
""" """
Crop the central reigion of the input PIL Image to the given size. Crop the central reigion of the input PIL image to the given size.
Args: Args:
size (Union[int, sequence]): The output size of the cropped image. size (Union[int, sequence]): The output size of the cropped image.
@ -498,10 +498,10 @@ class CenterCrop:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be center cropped. img (PIL image): Image to be center cropped.
Returns: Returns:
img (PIL Image), Cropped image. img (PIL image), Cropped image.
""" """
return util.center_crop(img, self.size) return util.center_crop(img, self.size)
@ -542,10 +542,10 @@ class RandomColorAdjust:
Call method. Call method.
Args: Args:
img (PIL Image): Image to have its color adjusted randomly. img (PIL image): Image to have its color adjusted randomly.
Returns: Returns:
img (PIL Image), Image after random adjustment of its color. img (PIL image), Image after random adjustment of its color.
""" """
return util.random_color_adjust(img, self.brightness, self.contrast, self.saturation, self.hue) return util.random_color_adjust(img, self.brightness, self.contrast, self.saturation, self.hue)
@ -601,10 +601,10 @@ class RandomRotation:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be rotated. img (PIL image): Image to be rotated.
Returns: Returns:
img (PIL Image), Rotated image. img (PIL image), Rotated image.
""" """
return util.random_rotation(img, self.degrees, self.resample, self.expand, self.center, self.fill_value) return util.random_rotation(img, self.degrees, self.resample, self.expand, self.center, self.fill_value)
@ -632,10 +632,10 @@ class RandomOrder:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be applied transformations in a random order. img (PIL image): Image to be applied transformations in a random order.
Returns: Returns:
img (PIL Image), Transformed image. img (PIL image), Transformed image.
""" """
return util.random_order(img, self.transforms) return util.random_order(img, self.transforms)
@ -665,10 +665,10 @@ class RandomApply:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be randomly applied a list transformations. img (PIL image): Image to be randomly applied a list transformations.
Returns: Returns:
img (PIL Image), Transformed image. img (PIL image), Transformed image.
""" """
return util.random_apply(img, self.transforms, self.prob) return util.random_apply(img, self.transforms, self.prob)
@ -696,10 +696,10 @@ class RandomChoice:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be applied transformation. img (PIL image): Image to be applied transformation.
Returns: Returns:
img (PIL Image), Transformed image. img (PIL image), Transformed image.
""" """
return util.random_choice(img, self.transforms) return util.random_choice(img, self.transforms)
@ -729,7 +729,7 @@ class FiveCrop:
Call method. Call method.
Args: Args:
img (PIL Image): PIL Image to be cropped. img (PIL image): PIL image to be cropped.
Returns: Returns:
img_tuple (tuple), a tuple of 5 PIL images img_tuple (tuple), a tuple of 5 PIL images
@ -768,7 +768,7 @@ class TenCrop:
Call method. Call method.
Args: Args:
img (PIL Image): PIL Image to be cropped. img (PIL image): PIL image to be cropped.
Returns: Returns:
img_tuple (tuple), a tuple of 10 PIL images img_tuple (tuple), a tuple of 10 PIL images
@ -801,10 +801,10 @@ class Grayscale:
Call method. Call method.
Args: Args:
img (PIL Image): PIL image to be converted to grayscale. img (PIL image): PIL image to be converted to grayscale.
Returns: Returns:
img (PIL Image), grayscaled image. img (PIL image), grayscaled image.
""" """
return util.grayscale(img, num_output_channels=self.num_output_channels) return util.grayscale(img, num_output_channels=self.num_output_channels)
@ -831,10 +831,10 @@ class RandomGrayscale:
Call method. Call method.
Args: Args:
img (PIL Image): PIL image to be converted to grayscale randomly. img (PIL image): PIL image to be converted to grayscale randomly.
Returns: Returns:
img (PIL Image), Randomly grayscale image, same number of channels as input image. img (PIL image), Randomly grayscale image, same number of channels as input image.
If input image has 1 channel, the output grayscale image is 1 channel. If input image has 1 channel, the output grayscale image is 1 channel.
If input image has 3 channels, the output image has 3 identical grayscale channels. If input image has 3 channels, the output image has 3 identical grayscale channels.
""" """
@ -895,17 +895,17 @@ class Pad:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be padded. img (PIL image): Image to be padded.
Returns: Returns:
img (PIL Image), Padded image. img (PIL image), Padded image.
""" """
return util.pad(img, self.padding, self.fill_value, self.padding_mode) return util.pad(img, self.padding, self.fill_value, self.padding_mode)
class RandomPerspective: class RandomPerspective:
""" """
Randomly apply perspective transformation to the input PIL Image with a given probability. Randomly apply perspective transformation to the input PIL image with a given probability.
Args: Args:
distortion_scale (float, optional): The scale of distortion, float between 0 and 1 (default=0.5). distortion_scale (float, optional): The scale of distortion, float between 0 and 1 (default=0.5).
@ -936,10 +936,10 @@ class RandomPerspective:
Call method. Call method.
Args: Args:
img (PIL Image): PIL Image to be applied perspective transformation randomly. img (PIL image): PIL image to be applied perspective transformation randomly.
Returns: Returns:
img (PIL Image), Image after being perspectively transformed randomly. img (PIL image), Image after being perspectively transformed randomly.
""" """
if self.prob > random.random(): if self.prob > random.random():
start_points, end_points = util.get_perspective_params(img, self.distortion_scale) start_points, end_points = util.get_perspective_params(img, self.distortion_scale)
@ -951,7 +951,7 @@ class RandomErasing:
""" """
Erase the pixels, within a selected rectangle region, to the given value. Erase the pixels, within a selected rectangle region, to the given value.
Randomly applied on the input Numpy image array with a given probability. Randomly applied on the input NumPy image array with a given probability.
Zhun Zhong et al. 'Random Erasing Data Augmentation' 2017 See https://arxiv.org/pdf/1708.04896.pdf Zhun Zhong et al. 'Random Erasing Data Augmentation' 2017 See https://arxiv.org/pdf/1708.04896.pdf
@ -989,10 +989,10 @@ class RandomErasing:
Call method. Call method.
Args: Args:
np_img (numpy.ndarray): Numpy image array of shape (C, H, W) to be randomly erased. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be randomly erased.
Returns: Returns:
np_img (numpy.ndarray), Erased Numpy image array. np_img (numpy.ndarray), Erased NumPy image array.
""" """
bounded = True bounded = True
if self.prob > random.random(): if self.prob > random.random():
@ -1004,7 +1004,7 @@ class RandomErasing:
class Cutout: class Cutout:
""" """
Randomly cut (mask) out a given number of square patches from the input Numpy image array. Randomly cut (mask) out a given number of square patches from the input NumPy image array.
Terrance DeVries and Graham W. Taylor 'Improved Regularization of Convolutional Neural Networks with Cutout' 2017 Terrance DeVries and Graham W. Taylor 'Improved Regularization of Convolutional Neural Networks with Cutout' 2017
See https://arxiv.org/pdf/1708.04552.pdf See https://arxiv.org/pdf/1708.04552.pdf
@ -1029,13 +1029,13 @@ class Cutout:
Call method. Call method.
Args: Args:
np_img (numpy.ndarray): Numpy image array of shape (C, H, W) to be cut out. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be cut out.
Returns: Returns:
np_img (numpy.ndarray), Numpy image array with square patches cut out. np_img (numpy.ndarray), NumPy image array with square patches cut out.
""" """
if not isinstance(np_img, np.ndarray): if not isinstance(np_img, np.ndarray):
raise TypeError('img should be Numpy array. Got {}'.format(type(np_img))) raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
_, image_h, image_w = np_img.shape _, image_h, image_w = np_img.shape
scale = (self.length * self.length) / (image_h * image_w) scale = (self.length * self.length) / (image_h * image_w)
bounded = False bounded = False
@ -1048,7 +1048,7 @@ class Cutout:
class LinearTransformation: class LinearTransformation:
""" """
Apply linear transformation to the input Numpy image array, given a square transformation matrix and Apply linear transformation to the input NumPy image array, given a square transformation matrix and
a mean_vector. a mean_vector.
The transformation first flattens the input array and subtract mean_vector from it, then computes the The transformation first flattens the input array and subtract mean_vector from it, then computes the
@ -1056,7 +1056,7 @@ class LinearTransformation:
Args: Args:
transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W. transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
mean_vector (numpy.ndarray): a numpy ndarray of shape (D,) where D = C x H x W. mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
Examples: Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(), >>> py_transforms.ComposeOp([py_transforms.Decode(),
@ -1075,7 +1075,7 @@ class LinearTransformation:
Call method. Call method.
Args: Args:
np_img (numpy.ndarray): Numpy image array of shape (C, H, W) to be linear transformed. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be linear transformed.
Returns: Returns:
np_img (numpy.ndarray), Linear transformed image. np_img (numpy.ndarray), Linear transformed image.
@ -1164,10 +1164,10 @@ class RandomAffine:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be applied affine transformation. img (PIL image): Image to be applied affine transformation.
Returns: Returns:
img (PIL Image), Randomly affine transformed image. img (PIL image), Randomly affine transformed image.
""" """
return util.random_affine(img, return util.random_affine(img,
@ -1203,12 +1203,12 @@ class MixUp:
Call method. Call method.
Args: Args:
image (numpy.ndarray): numpy Image to be applied mix up transformation. image (numpy.ndarray): NumPy image to be applied mix up transformation.
label(numpy.ndarray): numpy label to be applied mix up transformation. label(numpy.ndarray): NumPy label to be applied mix up transformation.
Returns: Returns:
image (numpy.ndarray): numpy Image after being applied mix up transformation. image (numpy.ndarray): NumPy image after being applied mix up transformation.
label(numpy.ndarray): numpy label after being applied mix up transformation. label(numpy.ndarray): NumPy label after being applied mix up transformation.
""" """
if self.is_single: if self.is_single:
return util.mix_up_single(self.batch_size, image, label, self.alpha) return util.mix_up_single(self.batch_size, image, label, self.alpha)
@ -1217,7 +1217,7 @@ class MixUp:
class RgbToHsv: class RgbToHsv:
""" """
Convert a Numpy RGB image or one batch Numpy RGB images to HSV images. Convert a NumPy RGB image or one batch NumPy RGB images to HSV images.
Args: Args:
is_hwc (bool): The flag of image shape, (H, W, C) or (N, H, W, C) if True is_hwc (bool): The flag of image shape, (H, W, C) or (N, H, W, C) if True
@ -1232,18 +1232,18 @@ class RgbToHsv:
Call method. Call method.
Args: Args:
rgb_imgs (numpy.ndarray): Numpy RGB images array of shape (H, W, C) or (N, H, W, C), rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
or (C, H, W) or (N, C, H, W) to be converted. or (C, H, W) or (N, C, H, W) to be converted.
Returns: Returns:
np_hsv_img (numpy.ndarray), Numpy HSV images with same shape of rgb_imgs. np_hsv_img (numpy.ndarray), NumPy HSV images with same shape of rgb_imgs.
""" """
return util.rgb_to_hsvs(rgb_imgs, self.is_hwc) return util.rgb_to_hsvs(rgb_imgs, self.is_hwc)
class HsvToRgb: class HsvToRgb:
""" """
Convert a Numpy HSV image or one batch Numpy HSV images to RGB images. Convert a NumPy HSV image or one batch NumPy HSV images to RGB images.
Args: Args:
is_hwc (bool): The flag of image shape, (H, W, C) or (N, H, W, C) if True is_hwc (bool): The flag of image shape, (H, W, C) or (N, H, W, C) if True
@ -1258,11 +1258,11 @@ class HsvToRgb:
Call method. Call method.
Args: Args:
hsv_imgs (numpy.ndarray): Numpy HSV images array of shape (H, W, C) or (N, H, W, C), hsv_imgs (numpy.ndarray): NumPy HSV images array of shape (H, W, C) or (N, H, W, C),
or (C, H, W) or (N, C, H, W) to be converted. or (C, H, W) or (N, C, H, W) to be converted.
Returns: Returns:
rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs. rgb_imgs (numpy.ndarray), NumPy RGB image with same shape of hsv_imgs.
""" """
return util.hsv_to_rgbs(hsv_imgs, self.is_hwc) return util.hsv_to_rgbs(hsv_imgs, self.is_hwc)
@ -1290,10 +1290,10 @@ class RandomColor:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be color adjusted. img (PIL image): Image to be color adjusted.
Returns: Returns:
img (PIL Image), Color adjusted image. img (PIL image), Color adjusted image.
""" """
return util.random_color(img, self.degrees) return util.random_color(img, self.degrees)
@ -1323,10 +1323,10 @@ class RandomSharpness:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be sharpness adjusted. img (PIL image): Image to be sharpness adjusted.
Returns: Returns:
img (PIL Image), Color adjusted image. img (PIL image), Color adjusted image.
""" """
return util.random_sharpness(img, self.degrees) return util.random_sharpness(img, self.degrees)
@ -1357,10 +1357,10 @@ class AutoContrast:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be augmented with AutoContrast. img (PIL image): Image to be augmented with AutoContrast.
Returns: Returns:
img (PIL Image), Augmented image. img (PIL image), Augmented image.
""" """
return util.auto_contrast(img, self.cutoff, self.ignore) return util.auto_contrast(img, self.cutoff, self.ignore)
@ -1382,10 +1382,10 @@ class Invert:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be color Inverted. img (PIL image): Image to be color Inverted.
Returns: Returns:
img (PIL Image), Color inverted image. img (PIL image), Color inverted image.
""" """
return util.invert_color(img) return util.invert_color(img)
@ -1407,10 +1407,10 @@ class Equalize:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be equalized. img (PIL image): Image to be equalized.
Returns: Returns:
img (PIL Image), Equalized image. img (PIL image), Equalized image.
""" """
return util.equalize(img) return util.equalize(img)
@ -1447,9 +1447,9 @@ class UniformAugment:
Call method. Call method.
Args: Args:
img (PIL Image): Image to be applied transformation. img (PIL image): Image to be applied transformation.
Returns: Returns:
img (PIL Image), Transformed image. img (PIL image), Transformed image.
""" """
return util.uniform_augment(img, self.transforms.copy(), self.num_ops) return util.uniform_augment(img, self.transforms.copy(), self.num_ops)

View File

@ -25,7 +25,7 @@ from PIL import Image, ImageOps, ImageEnhance, __version__
from .utils import Inter from .utils import Inter
augment_error_message = 'img should be PIL Image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data.' augment_error_message = 'img should be PIL image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data.'
def is_pil(img): def is_pil(img):
@ -43,13 +43,13 @@ def is_pil(img):
def is_numpy(img): def is_numpy(img):
""" """
Check if the input image is Numpy format. Check if the input image is NumPy format.
Args: Args:
img: Image to be checked. img: Image to be checked.
Returns: Returns:
Bool, True if input is Numpy image. Bool, True if input is NumPy image.
""" """
return isinstance(img, np.ndarray) return isinstance(img, np.ndarray)
@ -59,19 +59,19 @@ def compose(img, transforms):
Compose a list of transforms and apply on the image. Compose a list of transforms and apply on the image.
Args: Args:
img (numpy.ndarray): An image in Numpy ndarray. img (numpy.ndarray): An image in NumPy ndarray.
transforms (list): A list of transform Class objects to be composed. transforms (list): A list of transform Class objects to be composed.
Returns: Returns:
img (numpy.ndarray), An augmented image in Numpy ndarray. img (numpy.ndarray), An augmented image in NumPy ndarray.
""" """
if is_numpy(img): if is_numpy(img):
for transform in transforms: for transform in transforms:
img = transform(img) img = transform(img)
if is_numpy(img): if is_numpy(img):
return img return img
raise TypeError('img should be Numpy ndarray. Got {}. Append ToTensor() to transforms'.format(type(img))) raise TypeError('img should be NumPy ndarray. Got {}. Append ToTensor() to transforms'.format(type(img)))
raise TypeError('img should be Numpy ndarray. Got {}.'.format(type(img))) raise TypeError('img should be NumPy ndarray. Got {}.'.format(type(img)))
def normalize(img, mean, std): def normalize(img, mean, std):
@ -87,7 +87,7 @@ def normalize(img, mean, std):
img (numpy.ndarray), Normalized image. img (numpy.ndarray), Normalized image.
""" """
if not is_numpy(img): if not is_numpy(img):
raise TypeError('img should be Numpy Image. Got {}'.format(type(img))) raise TypeError('img should be NumPy image. Got {}'.format(type(img)))
num_channels = img.shape[0] # shape is (C, H, W) num_channels = img.shape[0] # shape is (C, H, W)
@ -109,13 +109,13 @@ def normalize(img, mean, std):
def decode(img): def decode(img):
""" """
Decode the input image to PIL Image format in RGB mode. Decode the input image to PIL image format in RGB mode.
Args: Args:
img: Image to be decoded. img: Image to be decoded.
Returns: Returns:
img (PIL Image), Decoded image in RGB mode. img (PIL image), Decoded image in RGB mode.
""" """
try: try:
@ -140,22 +140,22 @@ def hwc_to_chw(img):
""" """
if is_numpy(img): if is_numpy(img):
return img.transpose(2, 0, 1).copy() return img.transpose(2, 0, 1).copy()
raise TypeError('img should be Numpy array. Got {}'.format(type(img))) raise TypeError('img should be NumPy array. Got {}'.format(type(img)))
def to_tensor(img, output_type): def to_tensor(img, output_type):
""" """
Change the input image (PIL Image or Numpy image array) to numpy format. Change the input image (PIL image or NumPy image array) to NumPy format.
Args: Args:
img (Union[PIL Image, numpy.ndarray]): Image to be converted. img (Union[PIL image, numpy.ndarray]): Image to be converted.
output_type: The datatype of the numpy output. e.g. np.float32 output_type: The datatype of the NumPy output. e.g. np.float32
Returns: Returns:
img (numpy.ndarray), Converted image. img (numpy.ndarray), Converted image.
""" """
if not (is_pil(img) or is_numpy(img)): if not (is_pil(img) or is_numpy(img)):
raise TypeError('img should be PIL Image or Numpy array. Got {}'.format(type(img))) raise TypeError('img should be PIL image or NumPy array. Got {}'.format(type(img)))
img = np.asarray(img) img = np.asarray(img)
if img.ndim not in (2, 3): if img.ndim not in (2, 3):
@ -178,7 +178,7 @@ def to_pil(img):
img: Image to be converted. img: Image to be converted.
Returns: Returns:
img (PIL Image), Converted image. img (PIL image), Converted image.
""" """
if not is_pil(img): if not is_pil(img):
return Image.fromarray(img) return Image.fromarray(img)
@ -190,10 +190,10 @@ def horizontal_flip(img):
Flip the input image horizontally. Flip the input image horizontally.
Args: Args:
img (PIL Image): Image to be flipped horizontally. img (PIL image): Image to be flipped horizontally.
Returns: Returns:
img (PIL Image), Horizontally flipped image. img (PIL image), Horizontally flipped image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -206,10 +206,10 @@ def vertical_flip(img):
Flip the input image vertically. Flip the input image vertically.
Args: Args:
img (PIL Image): Image to be flipped vertically. img (PIL image): Image to be flipped vertically.
Returns: Returns:
img (PIL Image), Vertically flipped image. img (PIL image), Vertically flipped image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -222,12 +222,12 @@ def random_horizontal_flip(img, prob):
Randomly flip the input image horizontally. Randomly flip the input image horizontally.
Args: Args:
img (PIL Image): Image to be flipped. img (PIL image): Image to be flipped.
If the given probability is above the random probability, then the image is flipped. If the given probability is above the random probability, then the image is flipped.
prob (float): Probability of the image being flipped. prob (float): Probability of the image being flipped.
Returns: Returns:
img (PIL Image), Converted image. img (PIL image), Converted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -242,12 +242,12 @@ def random_vertical_flip(img, prob):
Randomly flip the input image vertically. Randomly flip the input image vertically.
Args: Args:
img (PIL Image): Image to be flipped. img (PIL image): Image to be flipped.
If the given probability is above the random probability, then the image is flipped. If the given probability is above the random probability, then the image is flipped.
prob (float): Probability of the image being flipped. prob (float): Probability of the image being flipped.
Returns: Returns:
img (PIL Image), Converted image. img (PIL image), Converted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -259,10 +259,10 @@ def random_vertical_flip(img, prob):
def crop(img, top, left, height, width): def crop(img, top, left, height, width):
""" """
Crop the input PIL Image. Crop the input PIL image.
Args: Args:
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image, img (PIL image): Image to be cropped. (0,0) denotes the top left corner of the image,
in the directions of (width, height). in the directions of (width, height).
top (int): Vertical component of the top left corner of the crop box. top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box.
@ -270,7 +270,7 @@ def crop(img, top, left, height, width):
width (int): Width of the crop box. width (int): Width of the crop box.
Returns: Returns:
img (PIL Image), Cropped image. img (PIL image), Cropped image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -280,10 +280,10 @@ def crop(img, top, left, height, width):
def resize(img, size, interpolation=Inter.BILINEAR): def resize(img, size, interpolation=Inter.BILINEAR):
""" """
Resize the input PIL Image to desired size. Resize the input PIL image to desired size.
Args: Args:
img (PIL Image): Image to be resized. img (PIL image): Image to be resized.
size (Union[int, sequence]): The output size of the resized image. size (Union[int, sequence]): The output size of the resized image.
If size is an int, smaller edge of the image will be resized to this value with If size is an int, smaller edge of the image will be resized to this value with
the same image aspect ratio. the same image aspect ratio.
@ -291,7 +291,7 @@ def resize(img, size, interpolation=Inter.BILINEAR):
interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2. interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.
Returns: Returns:
img (PIL Image), Resized image. img (PIL image), Resized image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -317,16 +317,16 @@ def resize(img, size, interpolation=Inter.BILINEAR):
def center_crop(img, size): def center_crop(img, size):
""" """
Crop the input PIL Image at the center to the given size. Crop the input PIL image at the center to the given size.
Args: Args:
img (PIL Image): Image to be cropped. img (PIL image): Image to be cropped.
size (Union[int, tuple]): The size of the crop box. size (Union[int, tuple]): The size of the crop box.
If size is an int, a square crop of size (size, size) is returned. If size is an int, a square crop of size (size, size) is returned.
If size is a sequence of length 2, it should be (height, width). If size is a sequence of length 2, it should be (height, width).
Returns: Returns:
img (PIL Image), Cropped image. img (PIL image), Cropped image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -342,10 +342,10 @@ def center_crop(img, size):
def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, max_attempts=10): def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, max_attempts=10):
""" """
Crop the input PIL Image to a random size and aspect ratio. Crop the input PIL image to a random size and aspect ratio.
Args: Args:
img (PIL Image): Image to be randomly cropped and resized. img (PIL image): Image to be randomly cropped and resized.
size (Union[int, sequence]): The size of the output image. size (Union[int, sequence]): The size of the output image.
If size is an int, a square crop of size (size, size) is returned. If size is an int, a square crop of size (size, size) is returned.
If size is a sequence of length 2, it should be (height, width). If size is a sequence of length 2, it should be (height, width).
@ -356,7 +356,7 @@ def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, ma
If exceeded, fall back to use center_crop instead. If exceeded, fall back to use center_crop instead.
Returns: Returns:
img (PIL Image), Randomly cropped and resized image. img (PIL image), Randomly cropped and resized image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -412,10 +412,10 @@ def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, ma
def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode): def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
""" """
Crop the input PIL Image at a random location. Crop the input PIL image at a random location.
Args: Args:
img (PIL Image): Image to be randomly cropped. img (PIL image): Image to be randomly cropped.
size (Union[int, sequence]): The output size of the cropped image. size (Union[int, sequence]): The output size of the cropped image.
If size is an int, a square crop of size (size, size) is returned. If size is an int, a square crop of size (size, size) is returned.
If size is a sequence of length 2, it should be (height, width). If size is a sequence of length 2, it should be (height, width).
@ -441,7 +441,7 @@ def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
value of edge value of edge
Returns: Returns:
PIL Image, Cropped image. PIL image, Cropped image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -483,12 +483,12 @@ def adjust_brightness(img, brightness_factor):
Adjust brightness of an image. Adjust brightness of an image.
Args: Args:
img (PIL Image): Image to be adjusted. img (PIL image): Image to be adjusted.
brightness_factor (float): A non negative number indicated the factor by which brightness_factor (float): A non negative number indicated the factor by which
the brightness is adjusted. 0 gives a black image, 1 gives the original. the brightness is adjusted. 0 gives a black image, 1 gives the original.
Returns: Returns:
img (PIL Image), Brightness adjusted image. img (PIL image), Brightness adjusted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -503,12 +503,12 @@ def adjust_contrast(img, contrast_factor):
Adjust contrast of an image. Adjust contrast of an image.
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL image): PIL image to be adjusted.
contrast_factor (float): A non negative number indicated the factor by which contrast_factor (float): A non negative number indicated the factor by which
the contrast is adjusted. 0 gives a solid gray image, 1 gives the original. the contrast is adjusted. 0 gives a solid gray image, 1 gives the original.
Returns: Returns:
img (PIL Image), Contrast adjusted image. img (PIL image), Contrast adjusted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -523,13 +523,13 @@ def adjust_saturation(img, saturation_factor):
Adjust saturation of an image. Adjust saturation of an image.
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL image): PIL image to be adjusted.
saturation_factor (float): A non negative number indicated the factor by which saturation_factor (float): A non negative number indicated the factor by which
the saturation is adjusted. 0 will give a black and white image, 1 will the saturation is adjusted. 0 will give a black and white image, 1 will
give the original. give the original.
Returns: Returns:
img (PIL Image), Saturation adjusted image. img (PIL image), Saturation adjusted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -544,7 +544,7 @@ def adjust_hue(img, hue_factor):
Adjust hue of an image. The Hue is changed by changing the HSV values after image is converted to HSV. Adjust hue of an image. The Hue is changed by changing the HSV values after image is converted to HSV.
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL image): PIL image to be adjusted.
hue_factor (float): Amount to shift the Hue channel. Value should be in hue_factor (float): Amount to shift the Hue channel. Value should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel. This [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel. This
is because Hue wraps around when rotated 360 degrees. is because Hue wraps around when rotated 360 degrees.
@ -552,7 +552,7 @@ def adjust_hue(img, hue_factor):
will give an image with complementary colors . will give an image with complementary colors .
Returns: Returns:
img (PIL Image), Hue adjusted image. img (PIL image), Hue adjusted image.
""" """
image = img image = img
image_hue_factor = hue_factor image_hue_factor = hue_factor
@ -580,17 +580,17 @@ def adjust_hue(img, hue_factor):
def to_type(img, output_type): def to_type(img, output_type):
""" """
Convert the Numpy image array to desired numpy dtype. Convert the NumPy image array to desired NumPy dtype.
Args: Args:
img (numpy): Numpy image to cast to desired numpy dtype. img (numpy): NumPy image to cast to desired NumPy dtype.
output_type (numpy datatype): Numpy dtype to cast to. output_type (Numpy datatype): NumPy dtype to cast to.
Returns: Returns:
img (numpy.ndarray), Converted image. img (numpy.ndarray), Converted image.
""" """
if not is_numpy(img): if not is_numpy(img):
raise TypeError('img should be Numpy Image. Got {}'.format(type(img))) raise TypeError('img should be NumPy image. Got {}'.format(type(img)))
return img.astype(output_type) return img.astype(output_type)
@ -600,7 +600,7 @@ def rotate(img, angle, resample, expand, center, fill_value):
Rotate the input PIL image by angle. Rotate the input PIL image by angle.
Args: Args:
img (PIL Image): Image to be rotated. img (PIL image): Image to be rotated.
angle (int or float): Rotation angle in degrees, counter-clockwise. angle (int or float): Rotation angle in degrees, counter-clockwise.
resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
@ -615,7 +615,7 @@ def rotate(img, angle, resample, expand, center, fill_value):
If it is an int, it is used for all RGB channels. If it is an int, it is used for all RGB channels.
Returns: Returns:
img (PIL Image), Rotated image. img (PIL image), Rotated image.
https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
""" """
@ -633,7 +633,7 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
Randomly adjust the brightness, contrast, saturation, and hue of an image. Randomly adjust the brightness, contrast, saturation, and hue of an image.
Args: Args:
img (PIL Image): Image to have its color adjusted randomly. img (PIL image): Image to have its color adjusted randomly.
brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative. brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative.
If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness]. If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness].
If it is a sequence, it should be [min, max] for the range. If it is a sequence, it should be [min, max] for the range.
@ -648,7 +648,7 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5. If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5.
Returns: Returns:
img (PIL Image), Image after random adjustment of its color. img (PIL image), Image after random adjustment of its color.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -695,7 +695,7 @@ def random_rotation(img, degrees, resample, expand, center, fill_value):
Rotate the input PIL image by a random angle. Rotate the input PIL image by a random angle.
Args: Args:
img (PIL Image): Image to be rotated. img (PIL image): Image to be rotated.
degrees (Union[int, float, sequence]): Range of random rotation degrees. degrees (Union[int, float, sequence]): Range of random rotation degrees.
If degrees is a number, the range will be converted to (-degrees, degrees). If degrees is a number, the range will be converted to (-degrees, degrees).
If degrees is a sequence, it should be (min, max). If degrees is a sequence, it should be (min, max).
@ -712,7 +712,7 @@ def random_rotation(img, degrees, resample, expand, center, fill_value):
If it is an int, it is used for all RGB channels. If it is an int, it is used for all RGB channels.
Returns: Returns:
img (PIL Image), Rotated image. img (PIL image), Rotated image.
https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
""" """
@ -788,7 +788,7 @@ def five_crop(img, size):
Generate 5 cropped images (one central and four corners). Generate 5 cropped images (one central and four corners).
Args: Args:
img (PIL Image): PIL Image to be cropped. img (PIL image): PIL image to be cropped.
size (Union[int, sequence]): The output size of the crop. size (Union[int, sequence]): The output size of the crop.
If size is an int, a square crop of size (size, size) is returned. If size is an int, a square crop of size (size, size) is returned.
If size is a sequence of length 2, it should be (height, width). If size is a sequence of length 2, it should be (height, width).
@ -807,7 +807,7 @@ def five_crop(img, size):
else: else:
raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.") raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")
# PIL Image.size returns in (width, height) order # PIL image.size returns in (width, height) order
img_width, img_height = img.size img_width, img_height = img.size
crop_height, crop_width = size crop_height, crop_width = size
if crop_height > img_height or crop_width > img_width: if crop_height > img_height or crop_width > img_width:
@ -828,7 +828,7 @@ def ten_crop(img, size, use_vertical_flip=False):
The default is horizontal flipping, use_vertical_flip=False. The default is horizontal flipping, use_vertical_flip=False.
Args: Args:
img (PIL Image): PIL Image to be cropped. img (PIL image): PIL image to be cropped.
size (Union[int, sequence]): The output size of the crop. size (Union[int, sequence]): The output size of the crop.
If size is an int, a square crop of size (size, size) is returned. If size is an int, a square crop of size (size, size) is returned.
If size is a sequence of length 2, it should be (height, width). If size is a sequence of length 2, it should be (height, width).
@ -866,11 +866,11 @@ def grayscale(img, num_output_channels):
Convert the input PIL image to grayscale image. Convert the input PIL image to grayscale image.
Args: Args:
img (PIL Image): PIL image to be converted to grayscale. img (PIL image): PIL image to be converted to grayscale.
num_output_channels (int): Number of channels of the output grayscale image (1 or 3). num_output_channels (int): Number of channels of the output grayscale image (1 or 3).
Returns: Returns:
img (PIL Image), grayscaled image. img (PIL image), grayscaled image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -894,7 +894,7 @@ def pad(img, padding, fill_value, padding_mode):
Pads the image according to padding parameters. Pads the image according to padding parameters.
Args: Args:
img (PIL Image): Image to be padded. img (PIL image): Image to be padded.
padding (Union[int, sequence], optional): The number of pixels to pad the image. padding (Union[int, sequence], optional): The number of pixels to pad the image.
If a single number is provided, it pads all borders with this value. If a single number is provided, it pads all borders with this value.
If a tuple or list of 2 values are provided, it pads the (left and top) If a tuple or list of 2 values are provided, it pads the (left and top)
@ -915,7 +915,7 @@ def pad(img, padding, fill_value, padding_mode):
value of edge value of edge
Returns: Returns:
img (PIL Image), Padded image. img (PIL image), Padded image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError(augment_error_message.format(type(img))) raise TypeError(augment_error_message.format(type(img)))
@ -990,16 +990,16 @@ def get_perspective_params(img, distortion_scale):
def perspective(img, start_points, end_points, interpolation=Inter.BICUBIC): def perspective(img, start_points, end_points, interpolation=Inter.BICUBIC):
""" """
Apply perspective transformation to the input PIL Image. Apply perspective transformation to the input PIL image.
Args: Args:
img (PIL Image): PIL Image to be applied perspective transformation. img (PIL image): PIL image to be applied perspective transformation.
start_points (list): List of [top_left, top_right, bottom_right, bottom_left] of the original image. start_points (list): List of [top_left, top_right, bottom_right, bottom_left] of the original image.
end_points: List of [top_left, top_right, bottom_right, bottom_left] of the transformed image. end_points: List of [top_left, top_right, bottom_right, bottom_left] of the transformed image.
interpolation (interpolation mode): Image interpolation mode, Default is Inter.BICUBIC = 3. interpolation (interpolation mode): Image interpolation mode, Default is Inter.BICUBIC = 3.
Returns: Returns:
img (PIL Image), Image after being perspectively transformed. img (PIL image), Image after being perspectively transformed.
""" """
def _input_to_coeffs(original_points, transformed_points): def _input_to_coeffs(original_points, transformed_points):
@ -1028,7 +1028,7 @@ def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
"""Helper function to get parameters for RandomErasing/ Cutout. """Helper function to get parameters for RandomErasing/ Cutout.
""" """
if not is_numpy(np_img): if not is_numpy(np_img):
raise TypeError('img should be Numpy array. Got {}'.format(type(np_img))) raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
image_c, image_h, image_w = np_img.shape image_c, image_h, image_w = np_img.shape
area = image_h * image_w area = image_h * image_w
@ -1076,10 +1076,10 @@ def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
def erase(np_img, i, j, height, width, erase_value, inplace=False): def erase(np_img, i, j, height, width, erase_value, inplace=False):
""" """
Erase the pixels, within a selected rectangle region, to the given value. Applied on the input Numpy image array. Erase the pixels, within a selected rectangle region, to the given value. Applied on the input NumPy image array.
Args: Args:
np_img (numpy.ndarray): Numpy image array of shape (C, H, W) to be erased. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be erased.
i (int): The height component of the top left corner (height, width). i (int): The height component of the top left corner (height, width).
j (int): The width component of the top left corner (height, width). j (int): The width component of the top left corner (height, width).
height (int): Height of the erased region. height (int): Height of the erased region.
@ -1088,10 +1088,10 @@ def erase(np_img, i, j, height, width, erase_value, inplace=False):
inplace (bool, optional): Apply this transform inplace. Default is False. inplace (bool, optional): Apply this transform inplace. Default is False.
Returns: Returns:
np_img (numpy.ndarray), Erased Numpy image array. np_img (numpy.ndarray), Erased NumPy image array.
""" """
if not is_numpy(np_img): if not is_numpy(np_img):
raise TypeError('img should be Numpy array. Got {}'.format(type(np_img))) raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
if not inplace: if not inplace:
np_img = np_img.copy() np_img = np_img.copy()
@ -1102,27 +1102,27 @@ def erase(np_img, i, j, height, width, erase_value, inplace=False):
def linear_transform(np_img, transformation_matrix, mean_vector): def linear_transform(np_img, transformation_matrix, mean_vector):
""" """
Apply linear transformation to the input Numpy image array, given a square transformation matrix and a mean_vector. Apply linear transformation to the input NumPy image array, given a square transformation matrix and a mean_vector.
The transformation first flattens the input array and subtract mean_vector from it, then computes the The transformation first flattens the input array and subtract mean_vector from it, then computes the
dot product with the transformation matrix, and reshapes it back to its original shape. dot product with the transformation matrix, and reshapes it back to its original shape.
Args: Args:
np_img (numpy.ndarray): Numpy image array of shape (C, H, W) to be linear transformed. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be linear transformed.
transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W. transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
mean_vector (numpy.ndarray): a numpy ndarray of shape (D,) where D = C x H x W. mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
Returns: Returns:
np_img (numpy.ndarray), Linear transformed image. np_img (numpy.ndarray), Linear transformed image.
""" """
if not is_numpy(np_img): if not is_numpy(np_img):
raise TypeError('img should be Numpy array. Got {}'.format(type(np_img))) raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
if transformation_matrix.shape[0] != transformation_matrix.shape[1]: if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
raise ValueError("transformation_matrix should be a square matrix. " raise ValueError("transformation_matrix should be a square matrix. "
"Got shape {} instead".format(transformation_matrix.shape)) "Got shape {} instead".format(transformation_matrix.shape))
if np.prod(np_img.shape) != transformation_matrix.shape[0]: if np.prod(np_img.shape) != transformation_matrix.shape[0]:
raise ValueError("transformation_matrix shape {0} not compatible with " raise ValueError("transformation_matrix shape {0} not compatible with "
"Numpy Image shape {1}.".format(transformation_matrix.shape, np_img.shape)) "Numpy image shape {1}.".format(transformation_matrix.shape, np_img.shape))
if mean_vector.shape[0] != transformation_matrix.shape[0]: if mean_vector.shape[0] != transformation_matrix.shape[0]:
raise ValueError("mean_vector length {0} should match either one dimension of the square " raise ValueError("mean_vector length {0} should match either one dimension of the square "
"transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape)) "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
@ -1136,7 +1136,7 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
Applies a random Affine transformation on the input PIL image. Applies a random Affine transformation on the input PIL image.
Args: Args:
img (PIL Image): Image to be applied affine transformation. img (PIL image): Image to be applied affine transformation.
angle (Union[int, float]): Rotation angle in degrees, clockwise. angle (Union[int, float]): Rotation angle in degrees, clockwise.
translations (sequence): Translations in horizontal and vertical axis. translations (sequence): Translations in horizontal and vertical axis.
scale (float): Scale parameter, a single number. scale (float): Scale parameter, a single number.
@ -1147,7 +1147,7 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
If None, no filling is performed. If None, no filling is performed.
Returns: Returns:
img (PIL Image), Randomly affine transformed image. img (PIL image), Randomly affine transformed image.
""" """
if not is_pil(img): if not is_pil(img):
@ -1253,13 +1253,13 @@ def mix_up_single(batch_size, img, label, alpha=0.2):
Args: Args:
batch_size (int): the batch size of dataset. batch_size (int): the batch size of dataset.
img (numpy.ndarray): numpy Image to be applied mix up transformation. img (numpy.ndarray): NumPy image to be applied mix up transformation.
label (numpy.ndarray): numpy label to be applied mix up transformation. label (numpy.ndarray): NumPy label to be applied mix up transformation.
alpha (float): the mix up rate. alpha (float): the mix up rate.
Returns: Returns:
mix_img (numpy.ndarray): numpy Image after being applied mix up transformation. mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
mix_label (numpy.ndarray): numpy label after being applied mix up transformation. mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
""" """
def cir_shift(data): def cir_shift(data):
@ -1284,13 +1284,13 @@ def mix_up_muti(tmp, batch_size, img, label, alpha=0.2):
Args: Args:
tmp (class object): mainly for saving the tmp parameter. tmp (class object): mainly for saving the tmp parameter.
batch_size (int): the batch size of dataset. batch_size (int): the batch size of dataset.
img (numpy.ndarray): numpy Image to be applied mix up transformation. img (numpy.ndarray): NumPy image to be applied mix up transformation.
label (numpy.ndarray): numpy label to be applied mix up transformation. label (numpy.ndarray): NumPy label to be applied mix up transformation.
alpha (float): refer to the mix up rate. alpha (float): refer to the mix up rate.
Returns: Returns:
mix_img (numpy.ndarray): numpy Image after being applied mix up transformation. mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
mix_label (numpy.ndarray): numpy label after being applied mix up transformation. mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
""" """
lam = np.random.beta(alpha, alpha, batch_size) lam = np.random.beta(alpha, alpha, batch_size)
if tmp.is_first: if tmp.is_first:
@ -1313,11 +1313,11 @@ def rgb_to_hsv(np_rgb_img, is_hwc):
Convert RGB img to HSV img. Convert RGB img to HSV img.
Args: Args:
np_rgb_img (numpy.ndarray): Numpy RGB image array of shape (H, W, C) or (C, H, W) to be converted. np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W). is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
Returns: Returns:
np_hsv_img (numpy.ndarray), Numpy HSV image with same type of np_rgb_img. np_hsv_img (numpy.ndarray), NumPy HSV image with same type of np_rgb_img.
""" """
if is_hwc: if is_hwc:
r, g, b = np_rgb_img[:, :, 0], np_rgb_img[:, :, 1], np_rgb_img[:, :, 2] r, g, b = np_rgb_img[:, :, 0], np_rgb_img[:, :, 1], np_rgb_img[:, :, 2]
@ -1338,16 +1338,16 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
Convert RGB imgs to HSV imgs. Convert RGB imgs to HSV imgs.
Args: Args:
np_rgb_imgs (numpy.ndarray): Numpy RGB images array of shape (H, W, C) or (N, H, W, C), np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
or (C, H, W) or (N, C, H, W) to be converted. or (C, H, W) or (N, C, H, W) to be converted.
is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C); is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W). If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).
Returns: Returns:
np_hsv_imgs (numpy.ndarray), Numpy HSV images with same type of np_rgb_imgs. np_hsv_imgs (numpy.ndarray), NumPy HSV images with same type of np_rgb_imgs.
""" """
if not is_numpy(np_rgb_imgs): if not is_numpy(np_rgb_imgs):
raise TypeError('img should be Numpy Image. Got {}'.format(type(np_rgb_imgs))) raise TypeError('img should be NumPy image. Got {}'.format(type(np_rgb_imgs)))
shape_size = len(np_rgb_imgs.shape) shape_size = len(np_rgb_imgs.shape)
@ -1380,11 +1380,11 @@ def hsv_to_rgb(np_hsv_img, is_hwc):
Convert HSV img to RGB img. Convert HSV img to RGB img.
Args: Args:
np_hsv_img (numpy.ndarray): Numpy HSV image array of shape (H, W, C) or (C, H, W) to be converted. np_hsv_img (numpy.ndarray): NumPy HSV image array of shape (H, W, C) or (C, H, W) to be converted.
is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W). is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
Returns: Returns:
np_rgb_img (numpy.ndarray), Numpy HSV image with same shape of np_hsv_img. np_rgb_img (numpy.ndarray), NumPy HSV image with same shape of np_hsv_img.
""" """
if is_hwc: if is_hwc:
h, s, v = np_hsv_img[:, :, 0], np_hsv_img[:, :, 1], np_hsv_img[:, :, 2] h, s, v = np_hsv_img[:, :, 0], np_hsv_img[:, :, 1], np_hsv_img[:, :, 2]
@ -1406,16 +1406,16 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
Convert HSV imgs to RGB imgs. Convert HSV imgs to RGB imgs.
Args: Args:
np_hsv_imgs (numpy.ndarray): Numpy HSV images array of shape (H, W, C) or (N, H, W, C), np_hsv_imgs (numpy.ndarray): NumPy HSV images array of shape (H, W, C) or (N, H, W, C),
or (C, H, W) or (N, C, H, W) to be converted. or (C, H, W) or (N, C, H, W) to be converted.
is_hwc (Bool): If True, the shape of np_hsv_imgs is (H, W, C) or (N, H, W, C); is_hwc (Bool): If True, the shape of np_hsv_imgs is (H, W, C) or (N, H, W, C);
If False, the shape of np_hsv_imgs is (C, H, W) or (N, C, H, W). If False, the shape of np_hsv_imgs is (C, H, W) or (N, C, H, W).
Returns: Returns:
np_rgb_imgs (numpy.ndarray), Numpy RGB images with same type of np_hsv_imgs. np_rgb_imgs (numpy.ndarray), NumPy RGB images with same type of np_hsv_imgs.
""" """
if not is_numpy(np_hsv_imgs): if not is_numpy(np_hsv_imgs):
raise TypeError('img should be Numpy Image. Got {}'.format(type(np_hsv_imgs))) raise TypeError('img should be NumPy image. Got {}'.format(type(np_hsv_imgs)))
shape_size = len(np_hsv_imgs.shape) shape_size = len(np_hsv_imgs.shape)
@ -1448,16 +1448,16 @@ def random_color(img, degrees):
Adjust the color of the input PIL image by a random degree. Adjust the color of the input PIL image by a random degree.
Args: Args:
img (PIL Image): Image to be color adjusted. img (PIL image): Image to be color adjusted.
degrees (sequence): Range of random color adjustment degrees. degrees (sequence): Range of random color adjustment degrees.
It should be in (min, max) format (default=(0.1,1.9)). It should be in (min, max) format (default=(0.1,1.9)).
Returns: Returns:
img (PIL Image), Color adjusted image. img (PIL image), Color adjusted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL image. Got {}'.format(type(img)))
v = (degrees[1] - degrees[0]) * random.random() + degrees[0] v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
return ImageEnhance.Color(img).enhance(v) return ImageEnhance.Color(img).enhance(v)
@ -1468,16 +1468,16 @@ def random_sharpness(img, degrees):
Adjust the sharpness of the input PIL image by a random degree. Adjust the sharpness of the input PIL image by a random degree.
Args: Args:
img (PIL Image): Image to be sharpness adjusted. img (PIL image): Image to be sharpness adjusted.
degrees (sequence): Range of random sharpness adjustment degrees. degrees (sequence): Range of random sharpness adjustment degrees.
It should be in (min, max) format (default=(0.1,1.9)). It should be in (min, max) format (default=(0.1,1.9)).
Returns: Returns:
img (PIL Image), Sharpness adjusted image. img (PIL image), Sharpness adjusted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL image. Got {}'.format(type(img)))
v = (degrees[1] - degrees[0]) * random.random() + degrees[0] v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
return ImageEnhance.Sharpness(img).enhance(v) return ImageEnhance.Sharpness(img).enhance(v)
@ -1488,17 +1488,17 @@ def auto_contrast(img, cutoff, ignore):
Automatically maximize the contrast of the input PIL image. Automatically maximize the contrast of the input PIL image.
Args: Args:
img (PIL Image): Image to be augmented with AutoContrast. img (PIL image): Image to be augmented with AutoContrast.
cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
ignore (Union[int, sequence], optional): Pixel values to ignore (default=None). ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).
Returns: Returns:
img (PIL Image), Augmented image. img (PIL image), Augmented image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img, cutoff, ignore) return ImageOps.autocontrast(img, cutoff, ignore)
@ -1508,15 +1508,15 @@ def invert_color(img):
Invert colors of input PIL image. Invert colors of input PIL image.
Args: Args:
img (PIL Image): Image to be color inverted. img (PIL image): Image to be color inverted.
Returns: Returns:
img (PIL Image), Color inverted image. img (PIL image), Color inverted image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL image. Got {}'.format(type(img)))
return ImageOps.invert(img) return ImageOps.invert(img)
@ -1526,15 +1526,15 @@ def equalize(img):
Equalize the histogram of input PIL image. Equalize the histogram of input PIL image.
Args: Args:
img (PIL Image): Image to be equalized img (PIL image): Image to be equalized
Returns: Returns:
img (PIL Image), Equalized image. img (PIL image), Equalized image.
""" """
if not is_pil(img): if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL image. Got {}'.format(type(img)))
return ImageOps.equalize(img) return ImageOps.equalize(img)

View File

@ -610,7 +610,7 @@ def check_bounding_box_augment_cpp(method):
def check_auto_contrast(method): def check_auto_contrast(method):
"""Wrapper method to check the parameters of AutoContrast ops (python and cpp).""" """Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -631,7 +631,7 @@ def check_auto_contrast(method):
def check_uniform_augment_py(method): def check_uniform_augment_py(method):
"""Wrapper method to check the parameters of python UniformAugment op.""" """Wrapper method to check the parameters of Python UniformAugment op."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
@ -656,7 +656,7 @@ def check_uniform_augment_py(method):
def check_positive_degrees(method): def check_positive_degrees(method):
"""A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (python and cpp)""" """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)"""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):

View File

@ -19,12 +19,15 @@ import mindspore.dataset.transforms.c_transforms as ops
def test_random_choice(): def test_random_choice():
"""
Test RandomChoice op
"""
ds.config.set_seed(0) ds.config.set_seed(0)
def test_config(arr, op_list): def test_config(arr, op_list):
try: try:
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False) data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
data = data.map(input_columns=["col"], operations=ops.RandomChoice(op_list)) data = data.map(operations=ops.RandomChoice(op_list), input_columns=["col"])
res = [] res = []
for i in data.create_dict_iterator(num_epochs=1): for i in data.create_dict_iterator(num_epochs=1):
res.append(i["col"].tolist()) res.append(i["col"].tolist())
@ -32,15 +35,16 @@ def test_random_choice():
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
return str(e) return str(e)
# test whether a op would be randomly chosen. In order to prevent random failure, both results need to be checked # Test whether an operation would be randomly chosen.
# In order to prevent random failure, both results need to be checked.
res1 = test_config([[0, 1, 2]], [ops.PadEnd([4], 0), ops.Slice([0, 2])]) res1 = test_config([[0, 1, 2]], [ops.PadEnd([4], 0), ops.Slice([0, 2])])
assert res1 in [[[0, 1, 2, 0]], [[0, 2]]] assert res1 in [[[0, 1, 2, 0]], [[0, 2]]]
# test nested structure # Test nested structure
res2 = test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate()]), res2 = test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate()]),
ops.Compose([ops.Slice([0, 1]), ops.OneHot(2)])]) ops.Compose([ops.Slice([0, 1]), ops.OneHot(2)])])
assert res2 in [[[[1, 0], [0, 1]]], [[0, 1, 2, 0, 1, 2]]] assert res2 in [[[[1, 0], [0, 1]]], [[0, 1, 2, 0, 1, 2]]]
# test random_choice where there is only 1 op # Test RandomChoice where there is only 1 operation
assert test_config([[4, 3], [2, 1]], [ops.Slice([0])]) == [[4], [2]] assert test_config([[4, 3], [2, 1]], [ops.Slice([0])]) == [[4], [2]]

View File

@ -89,7 +89,7 @@ def test_five_crop_error_msg():
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
for _ in data: for _ in data:
pass pass
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>" error_msg = "TypeError: img should be PIL image or NumPy array. Got <class 'tuple'>"
# error msg comes from ToTensor() # error msg comes from ToTensor()
assert error_msg in str(info.value) assert error_msg in str(info.value)

View File

@ -500,7 +500,7 @@ def test_random_crop_09():
data.create_dict_iterator(num_epochs=1).get_next() data.create_dict_iterator(num_epochs=1).get_next()
except RuntimeError as e: except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e))) logger.info("Got an exception in DE: {}".format(str(e)))
assert "should be PIL Image" in str(e) assert "should be PIL image" in str(e)
def test_random_crop_comp(plot=False): def test_random_crop_comp(plot=False):
""" """

View File

@ -175,7 +175,7 @@ def test_resize_with_bbox_op_bad_c():
def test_resize_with_bbox_op_params_outside_of_interpolation_dict(): def test_resize_with_bbox_op_params_outside_of_interpolation_dict():
""" """
Test passing in a invalid key for interpolation Test passing in an invalid key for interpolation
""" """
logger.info("test_resize_with_bbox_op_params_outside_of_interpolation_dict") logger.info("test_resize_with_bbox_op_params_outside_of_interpolation_dict")

View File

@ -174,7 +174,7 @@ def test_ten_crop_wrong_img_error_msg():
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
data.create_tuple_iterator(num_epochs=1).get_next() data.create_tuple_iterator(num_epochs=1).get_next()
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>" error_msg = "TypeError: img should be PIL image or NumPy array. Got <class 'tuple'>"
# error msg comes from ToTensor() # error msg comes from ToTensor()
assert error_msg in str(info.value) assert error_msg in str(info.value)