fix bugs of execute & update doc description in minddata

luoyang 2021-02-05 17:11:52 +08:00
parent 24e528112f
commit cd86788cf3
19 changed files with 163 additions and 99 deletions

View File

@ -51,12 +51,19 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(input.Shape()),
MSTypeToDEType(static_cast<TypeId>(input.DataType())),
(const uchar *)(input.Data().get()), input.DataSize(), &de_tensor);
-RETURN_IF_NOT_OK(rc);
+if (rc.IsError()) {
+  MS_LOG(ERROR) << rc;
+  RETURN_IF_NOT_OK(rc);
+}
// Apply transforms on tensor
for (auto &t : transforms) {
std::shared_ptr<dataset::Tensor> de_output;
-RETURN_IF_NOT_OK(t->Compute(de_tensor, &de_output));
+Status rc_ = t->Compute(de_tensor, &de_output);
+if (rc_.IsError()) {
+  MS_LOG(ERROR) << rc_;
+  RETURN_IF_NOT_OK(rc_);
+}
// For next transform
de_tensor = std::move(de_output);
@ -90,14 +97,21 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::
Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(tensor.Shape()),
MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
(const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_tensor);
-RETURN_IF_NOT_OK(rc);
+if (rc.IsError()) {
+  MS_LOG(ERROR) << rc;
+  RETURN_IF_NOT_OK(rc);
+}
de_tensor_list.emplace_back(std::move(de_tensor));
}
// Apply transforms on tensor
for (auto &t : transforms) {
TensorRow de_output_list;
-RETURN_IF_NOT_OK(t->Compute(de_tensor_list, &de_output_list));
+Status rc = t->Compute(de_tensor_list, &de_output_list);
+if (rc.IsError()) {
+  MS_LOG(ERROR) << rc;
+  RETURN_IF_NOT_OK(rc);
+}
// For next transform
de_tensor_list = std::move(de_output_list);
}
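
Net effect of the two hunks above: when a transform fails inside Execute, the offending Status is now logged via MS_LOG(ERROR) before being propagated, instead of being returned silently. On the Python side this surfaces through the eager __call__ bindings in the next file. A minimal sketch of that eager path, assuming the eager-transform API of this MindSpore release (input array and op choice are illustrative):

import numpy as np
import mindspore.dataset.vision.c_transforms as c_vision

# Calling a C++ vision op on raw data routes through Execute::operator().
img = np.random.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)
out = c_vision.Resize([32, 32])(img)
print(out.shape)  # expected (32, 32, 3); failures now raise and log the Status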

View File

@ -31,12 +31,7 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
.def("__call__",
[](Execute &self, const std::shared_ptr<Tensor> &de_tensor) {
auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
-Status rc = self(ms_tensor, &ms_tensor);
-if (rc.IsError()) {
-  THROW_IF_ERROR([&rc]() {
-    RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString());
-  }());
-}
+THROW_IF_ERROR(self(ms_tensor, &ms_tensor));
std::shared_ptr<dataset::Tensor> de_output_tensor;
dataset::Tensor::CreateFromMemory(dataset::TensorShape(ms_tensor.Shape()),
MSTypeToDEType(static_cast<TypeId>(ms_tensor.DataType())),
@ -51,11 +46,7 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
ms_input_tensor_list.emplace_back(std::move(ms_tensor));
}
-Status rc = self(ms_input_tensor_list, &ms_output_tensor_list);
-if (rc.IsError()) {
-  THROW_IF_ERROR(
-    [&rc]() { RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString()); }());
-}
+THROW_IF_ERROR(self(ms_input_tensor_list, &ms_output_tensor_list));
std::vector<std::shared_ptr<dataset::Tensor>> de_output_tensor_list;
for (auto &tensor : ms_output_tensor_list) {
std::shared_ptr<dataset::Tensor> de_output_tensor;

View File

@ -825,8 +825,8 @@ std::shared_ptr<TensorOp> PadOperation::Build() {
break;
case 2:
pad_left = padding_[0];
-pad_top = padding_[1];
-pad_right = padding_[0];
+pad_top = padding_[0];
+pad_right = padding_[1];
pad_bottom = padding_[1];
break;
default:
@ -1096,8 +1096,8 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
break;
case 2:
pad_left = padding_[0];
-pad_top = padding_[1];
-pad_right = padding_[0];
+pad_top = padding_[0];
+pad_right = padding_[1];
pad_bottom = padding_[1];
break;
default:
@ -1215,8 +1215,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() {
break;
case 2:
pad_left = padding_[0];
-pad_top = padding_[1];
-pad_right = padding_[0];
+pad_top = padding_[0];
+pad_right = padding_[1];
pad_bottom = padding_[1];
break;
default:
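
The three Build() fixes above (Pad, RandomCrop, RandomCropWithBBox) all correct the same rule for a two-value padding vector. A minimal Python sketch of the corrected expansion (the helper name is illustrative, not MindSpore API):

def expand_padding(padding):
    # 1 value: pad all sides; 2 values: first pads (left, top),
    # second pads (right, bottom); 4 values: (left, top, right, bottom).
    if len(padding) == 1:
        return (padding[0],) * 4
    if len(padding) == 2:
        return (padding[0], padding[0], padding[1], padding[1])
    return tuple(padding)

assert expand_padding([100, 9]) == (100, 100, 9, 9)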

View File

@ -141,7 +141,7 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p
const JiebaMode &mode = JiebaMode::kMix,
bool with_offsets = false);
-/// \brief Lookup operator that looks up a word to an id.
+/// \brief Look up a word into an id according to the input vocabulary table.
/// \param[in] vocab a Vocab object.
/// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
/// If unknown_token is oov, runtime error will be thrown. If unknown_token is {}, which means that not to
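
A short usage sketch of Lookup from Python, assuming the public text API (vocabulary contents are illustrative):

import mindspore.dataset.text as text

# Out-of-vocabulary tokens map to the id of unknown_token ("<unk>" here).
vocab = text.Vocab.from_list(["home", "behind", "the", "world", "<unk>"])
lookup = text.Lookup(vocab, unknown_token="<unk>")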

View File

@ -200,8 +200,8 @@ std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &me
/// \notes Pads the image according to padding parameters
/// \param[in] padding A vector representing the number of pixels to pad the image
/// If vector has one value, it pads all sides of the image with that value.
-/// If vector has two values, it pads left and right with the first and
-/// top and bottom with the second value.
+/// If vector has two values, it pads left and top with the first and
+/// right and bottom with the second value.
/// If vector has four values, it pads left, top, right, and bottom with
/// those values respectively.
/// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is
@ -270,8 +270,12 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
/// \param[in] size A vector representing the output size of the cropped image.
/// If size is a single value, a square crop of size (size, size) is returned.
/// If size has 2 values, it should be (height, width).
-/// \param[in] padding A vector with the value of pixels to pad the image. If 4 values are provided,
-/// it pads the left, top, right and bottom respectively.
+/// \param[in] padding A vector representing the number of pixels to pad the image.
+/// If vector has one value, it pads all sides of the image with that value.
+/// If vector has two values, it pads left and top with the first and
+/// right and bottom with the second value.
+/// If vector has four values, it pads left, top, right, and bottom with
+/// those values respectively.
/// \param[in] pad_if_needed A boolean whether to pad the image if either side is smaller than
/// the given output size.
/// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is
@ -304,8 +308,12 @@ std::shared_ptr<RandomCropDecodeResizeOperation> RandomCropDecodeResize(
/// \param[in] size A vector representing the output size of the cropped image.
/// If size is a single value, a square crop of size (size, size) is returned.
/// If size has 2 values, it should be (height, width).
-/// \param[in] padding A vector with the value of pixels to pad the image. If 4 values are provided,
-/// it pads the left, top, right and bottom respectively.
+/// \param[in] padding A vector representing the number of pixels to pad the image.
+/// If vector has one value, it pads all sides of the image with that value.
+/// If vector has two values, it pads left and top with the first and
+/// right and bottom with the second value.
+/// If vector has four values, it pads left, top, right, and bottom with
+/// those values respectively.
/// \param[in] pad_if_needed A boolean whether to pad the image if either side is smaller than
/// the given output size.
/// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is

View File

@ -24,7 +24,7 @@ namespace dataset {
Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
-CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Duplicate: only support one input.");
+CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Duplicate: only supports transforming one column each time.");
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input[0], &out));
output->push_back(input[0]);

View File

@ -103,8 +103,12 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
RETURN_STATUS_UNEXPECTED("Resize: load image failed.");
}
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Resize: input is not in shape of <H,W,C> or <H,W>");
RETURN_STATUS_UNEXPECTED("Resize: input tensor is not in shape of <H,W,C> or <H,W>");
}
if (input_cv->shape()[2] != 3 && input_cv->shape()[2] != 1) {
RETURN_STATUS_UNEXPECTED("Resize: channel of input tesnor is not in 1 or 3.");
}
cv::Mat in_image = input_cv->mat();
// resize image too large or too small
if (output_height > in_image.rows * 1000 || output_width > in_image.cols * 1000) {
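
With the added channel check, an input whose last dimension is neither 1 nor 3 now fails fast instead of reaching OpenCV. A hedged sketch of the eager behavior (the 2-channel array is illustrative):

import numpy as np
import mindspore.dataset.vision.c_transforms as c_vision

bad = np.zeros((32, 32, 2), dtype=np.uint8)  # 2 channels: rejected
try:
    c_vision.Resize([64, 64])(bad)
except RuntimeError as err:
    print(err)  # should contain "Resize: channel of input tensor is not 1 or 3."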

View File

@ -41,7 +41,7 @@ Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
/// Get number of channels and image matrix
std::size_t num_of_channels = input_cv->shape()[2];
if (num_of_channels != 1 && num_of_channels != 3) {
RETURN_STATUS_UNEXPECTED("Sharpness: image shape is not <H,W,C>.");
RETURN_STATUS_UNEXPECTED("Sharpness: image channel is not 1 or 3.");
}
/// creating a smoothing filter. 1, 1, 1,

View File

@ -578,7 +578,7 @@ class Dataset:
python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
option could be beneficial if the Python operation is computational heavy (default=False).
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None).
@ -2328,7 +2328,7 @@ class MapDataset(Dataset):
python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
option could be beneficial if the Python operation is computational heavy (default=False).
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None)
Raises:
@ -3024,7 +3024,7 @@ class ImageFolderDataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -3170,7 +3170,7 @@ class MnistDataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -3234,7 +3234,7 @@ class MnistDataset(MappableDataset):
class MindDataset(MappableDataset):
"""
-A source dataset that reads MindRecord files.
+A source dataset for reading and parsing MindRecord dataset.
Args:
dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents for
@ -3811,7 +3811,7 @@ class GeneratorDataset(MappableDataset):
class TFRecordDataset(SourceDataset):
"""
-A source dataset that reads and parses datasets stored on disk in TFData format.
+A source dataset for reading and parsing datasets stored on disk in TFData format.
Args:
dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
@ -3843,7 +3843,7 @@ class TFRecordDataset(SourceDataset):
shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
is false, number of rows of each shard may be not equal.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Examples:
>>> import mindspore.common.dtype as mstype
@ -3971,7 +3971,7 @@ class TFRecordDataset(SourceDataset):
class ManifestDataset(MappableDataset):
"""
-A source dataset that reads images from a manifest file.
+A source dataset for reading images from a Manifest file.
The generated dataset has two columns ['image', 'label'].
The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
@ -4027,7 +4027,7 @@ class ManifestDataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4122,7 +4122,7 @@ class ManifestDataset(MappableDataset):
class Cifar10Dataset(MappableDataset):
"""
-A source dataset that reads cifar10 data.
+A source dataset for reading and parsing Cifar10 dataset.
The generated dataset has two columns ['image', 'label'].
The type of the image tensor is uint8. The label is a scalar uint32 tensor.
@ -4188,7 +4188,7 @@ class Cifar10Dataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4258,7 +4258,7 @@ class Cifar10Dataset(MappableDataset):
class Cifar100Dataset(MappableDataset):
"""
-A source dataset that reads cifar100 data.
+A source dataset for reading and parsing Cifar100 dataset.
The generated dataset has three columns ['image', 'coarse_label', 'fine_label'].
The type of the image tensor is uint8. The coarse and fine labels are each a scalar uint32 tensor.
@ -4326,7 +4326,7 @@ class Cifar100Dataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4404,7 +4404,7 @@ class RandomDataset(SourceDataset):
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
(default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset will be divided
@ -4667,7 +4667,7 @@ class VOCDataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If xml of Annotations is an invalid format.
@ -4860,7 +4860,7 @@ class CocoDataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -5009,7 +5009,7 @@ class CelebADataset(MappableDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Examples:
>>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='train')
@ -5120,7 +5120,7 @@ class CLUEDataset(SourceDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Examples:
>>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # contains 1 or multiple text files
@ -5353,7 +5353,7 @@ class CSVDataset(SourceDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Examples:
@ -5459,7 +5459,7 @@ class TextFileDataset(SourceDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
-(default=None which means no cache is used).
+(default=None, which means no cache is used).
Examples:
>>> # contains 1 or multiple text files

View File

@ -586,7 +586,7 @@ class SubsetSampler(BuiltinSampler):
for i, item in enumerate(indices):
if not isinstance(item, numbers.Number):
raise TypeError("type of weights element should be number, "
raise TypeError("type of indices element should be number, "
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
self.indices = indices
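
The corrected message now names the argument actually being validated. A quick illustration, assuming SubsetSampler is exposed under mindspore.dataset (values illustrative):

import mindspore.dataset as ds

# A non-numeric element now reports "type of indices element should be number".
ds.SubsetSampler(indices=[0, "1"])  # raises TypeError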

View File

@ -275,7 +275,7 @@ class JiebaTokenizer(TextTensorOperation):
class Lookup(TextTensorOperation):
"""
-Lookup operator that looks up a word to an id.
+Look up a word into an id according to the input vocabulary table.
Args:
vocab (Vocab): A vocabulary object.

View File

@ -39,13 +39,14 @@ class OneHot(cde.OneHotOp):
Tensor operation to apply one hot encoding.
Args:
-num_classes (int): Number of classes of the label.
+num_classes (int): Number of classes of objects in dataset.
+It should be larger than the largest label number in the dataset.
Raises:
RuntimeError: feature size is bigger than num_classes.
Examples:
>>> # Assume that dataset has 10 classes, thus the label ranges from 0 to 9
>>> onehot_op = c_transforms.OneHot(num_classes=10)
>>> mnist_dataset = mnist_dataset.map(operations=onehot_op, input_columns=["label"])
"""
@ -114,13 +115,13 @@ class _SliceOption(cde.SliceOption):
Internal class SliceOption to be used with SliceOperation
Args:
-_SliceOption(Union[int, list(int), slice, None, Ellipses, bool, _SliceOption]):
+_SliceOption(Union[int, list(int), slice, None, Ellipsis, bool, _SliceOption]):
1. :py:obj:`int`: Slice this index only along the dimension. Negative index is supported.
2. :py:obj:`list(int)`: Slice these indices along the dimension. Negative indices are supported.
3. :py:obj:`slice`: Slice the generated indices from the slice object along the dimension.
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing.
-5. :py:obj:`Ellipses`: Slice the whole dimension. Similar to `:` in Python indexing.
+5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing.
6. :py:obj:`boolean`: Slice the whole dimension. Similar to `:` in Python indexing.
"""
@ -143,16 +144,16 @@ class Slice(cde.SliceOp):
(Currently only rank-1 tensors are supported).
Args:
-*slices(Union[int, list(int), slice, None, Ellipses]):
+*slices(Union[int, list(int), slice, None, Ellipsis]):
Maximum `n` number of arguments to slice a tensor of rank `n`.
One object in slices can be one of:
1. :py:obj:`int`: Slice this index only along the first dimension. Negative index is supported.
2. :py:obj:`list(int)`: Slice these indices along the first dimension. Negative indices are supported.
3. :py:obj:`slice`: Slice the generated indices from the slice object along the first dimension.
-Similar to `start:stop:step`.
+Similar to start:stop:step.
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing.
-5. :py:obj:`Ellipses`: Slice the whole dimension. Similar to `:` in Python indexing.
+5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing.
Examples:
>>> # Data before
@ -232,7 +233,7 @@ class Mask(cde.MaskOp):
class PadEnd(cde.PadEndOp):
"""
-Pad input tensor according to `pad_shape`, need to have same rank.
+Pad input tensor according to pad_shape, need to have same rank.
Args:
pad_shape (list(int)): List of integers representing the shape needed. Dimensions that set to `None` will
@ -295,7 +296,7 @@ class Concatenate(cde.ConcatenateOp):
class Duplicate(cde.DuplicateOp):
"""
-Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list.
+Duplicate the input tensor to the output; only supports transforming one column each time.
Examples:
>>> # Data before
@ -405,7 +406,7 @@ class RandomApply():
class RandomChoice():
"""
-Randomly selects one transform from a list of transforms to perform operation.
+Randomly select one transform from a list of transforms to perform operation.
Args:
transforms (list): List of transformations to be chosen from to apply.

View File

@ -26,11 +26,13 @@ class OneHotOp:
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
Args:
-num_classes (int): Number of classes of objects in dataset. Value must be larger than 0.
+num_classes (int): Number of classes of objects in dataset.
+It should be larger than the largest label number in the dataset.
smoothing_rate (float, optional): Adjustable hyperparameter for label smoothing level.
(Default=0.0 means no smoothing is applied.)
Examples:
>>> # Assume that dataset has 10 classes, thus the label ranges from 0 to 9
>>> transforms_list = [py_transforms.OneHotOp(num_classes=10, smoothing_rate=0.1)]
>>> transform = py_transforms.Compose(transforms_list)
>>> mnist_dataset = mnist_dataset(input_columns=["label"], operations=transform)
@ -83,14 +85,14 @@ class Compose:
>>>
>>> # Compose is also be invoked implicitly, by just passing in a list of ops
>>> # the above example then becomes:
>>> transform_list = [py_vision.Decode(),
... py_vision.RandomHorizontalFlip(0.5),
... py_vision.ToTensor(),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms_list = [py_vision.Decode(),
... py_vision.RandomHorizontalFlip(0.5),
... py_vision.ToTensor(),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>>
>>> # apply the transform to the dataset through dataset.map()
>>> image_folder_dataset_1 = image_folder_dataset_1.map(operations=transform_list, input_columns=["image"])
>>> image_folder_dataset_1 = image_folder_dataset_1.map(operations=transforms_list, input_columns=["image"])
>>>
>>> # Certain C++ and Python ops can be combined, but not all of them
>>> # An example of combined operations
@ -134,9 +136,9 @@ class RandomApply:
Examples:
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>> transform_list = [py_vision.RandomHorizontalFlip(0.5),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms_list = [py_vision.RandomHorizontalFlip(0.5),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms = Compose([py_vision.Decode(),
... py_transforms.RandomApply(transforms_list, prob=0.6),
... py_vision.ToTensor()])
@ -170,11 +172,11 @@ class RandomChoice:
Examples:
>>> from mindspore.dataset.transforms.py_transforms import Compose, RandomChoice
>>> transform_list = [py_vision.RandomHorizontalFlip(0.5),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms_list = [py_vision.RandomHorizontalFlip(0.5),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms = Compose([py_vision.Decode(),
-... py_transforms.RandomChoice(transform_list),
+... py_transforms.RandomChoice(transforms_list),
... py_vision.ToTensor()])
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms, input_columns=["image"])
"""
@ -205,9 +207,9 @@ class RandomOrder:
Examples:
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>> transform_list = [py_vision.RandomHorizontalFlip(0.5),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms_list = [py_vision.RandomHorizontalFlip(0.5),
... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
... py_vision.RandomErasing()]
>>> transforms = Compose([py_vision.Decode(),
... py_transforms.RandomOrder(transforms_list),
... py_vision.ToTensor()])

View File

@ -108,8 +108,8 @@ def parse_padding(padding):
if isinstance(padding, numbers.Number):
padding = [padding] * 4
if len(padding) == 2:
-left = right = padding[0]
-top = bottom = padding[1]
+left = top = padding[0]
+right = bottom = padding[1]
padding = (left, top, right, bottom,)
if isinstance(padding, list):
padding = tuple(padding)
@ -438,8 +438,8 @@ class Pad(ImageTensorOperation):
Args:
padding (Union[int, sequence]): The number of pixels to pad the image.
If a single number is provided, it pads all borders with this value.
-If a tuple or list of 2 values are provided, it pads left and right
-with the first value and top and bottom with the second value.
+If a tuple or list of 2 values are provided, it pads the (left and top)
+with the first value and (right and bottom) with the second value.
If 4 values are provided as a list or tuple,
it pads the left, top, right and bottom respectively.
fill_value (Union[int, tuple], optional): The pixel intensity of the borders, only valid for
@ -674,8 +674,8 @@ class RandomCrop(ImageTensorOperation):
padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None).
If padding is not None, pad image firstly with padding values.
If a single number is provided, pad all borders with this value.
-If a tuple or list of 2 values are provided, it pads left and right
-with the first value and top and bottom with the second value.
+If a tuple or list of 2 values are provided, pad the (left and top)
+with the first value and (right and bottom) with the second value.
If 4 values are provided as a list or tuple,
pad the left, top, right and bottom respectively.
pad_if_needed (bool, optional): Pad the image if either side is smaller than
@ -790,8 +790,8 @@ class RandomCropWithBBox(ImageTensorOperation):
padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None).
If padding is not None, first pad image with padding values.
If a single number is provided, pad all borders with this value.
-If a tuple or list of 2 values are provided, it pads left and right
-with the first value and top and bottom with the second value.
+If a tuple or list of 2 values are provided, pad the (left and top)
+with the first value and (right and bottom) with the second value.
If 4 values are provided as a list or tuple, pad the left, top, right and bottom respectively.
pad_if_needed (bool, optional): Pad the image if either side is smaller than
the given output size (default=False).
@ -845,7 +845,7 @@ class RandomCropWithBBox(ImageTensorOperation):
class RandomHorizontalFlip(ImageTensorOperation):
"""
-Flip the input image horizontally, randomly with a given probability.
+Randomly flip the input image horizontally with a given probability.
Args:
prob (float, optional): Probability of the image being flipped (default=0.5).
@ -1202,12 +1202,12 @@ class RandomSharpness(ImageTensorOperation):
class RandomSolarize(ImageTensorOperation):
"""
-Invert all pixel values above a threshold.
+Invert all pixel values within the given range.
Args:
threshold (tuple, optional): Range of random solarize threshold. Threshold values should always be
-in the range (0, 255), include at least one integer value in the given range and
-be in (min, max) format. If min=max, then it is a single fixed magnitude operation (default=(0, 255)).
+in the range (0, 255), include at least one integer value in the given range and be in
+(min, max) format. If min=max, then invert all pixel values above min(max) (default=(0, 255)).
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.RandomSolarize(threshold=(10,100))]
@ -1225,7 +1225,7 @@ class RandomSolarize(ImageTensorOperation):
class RandomVerticalFlip(ImageTensorOperation):
"""
-Flip the input image vertically, randomly with a given probability.
+Randomly flip the input image vertically with a given probability.
Args:
prob (float, optional): Probability of the image being flipped (default=0.5).

View File

@ -1357,7 +1357,8 @@ class RandomColor:
class RandomSharpness:
"""
-Adjust the sharpness of the input PIL image by a random degree.
+Adjust the sharpness of the input PIL image by a fixed or random degree. Degree of 0.0 gives a blurred image,
+degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image.
Args:
degrees (sequence): Range of random sharpness adjustment degrees.

View File

@ -499,7 +499,8 @@ def check_cutout(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[length, num_patches], _ = parse_user_args(method, *args, **kwargs)
type_check(length, (int,), "length")
type_check(num_patches, (int,), "num_patches")
check_value(length, (1, FLOAT_MAX_INTEGER))
check_value(num_patches, (1, FLOAT_MAX_INTEGER))
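
With the two added type_check calls, CutOut rejects non-integer arguments up front rather than failing later. A hedged sketch (values illustrative):

import mindspore.dataset.vision.c_transforms as c_vision

c_vision.CutOut(length=8, num_patches=1)    # ok
c_vision.CutOut(length=8.0, num_patches=1)  # now raises TypeError from type_check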
@ -613,7 +614,10 @@ def check_uniform_augment_cpp(method):
parsed_transforms.append(op.parse())
else:
parsed_transforms.append(op)
-type_check_list(parsed_transforms, (TensorOp, TensorOperation), "transforms")
+type_check(parsed_transforms, (list, tuple,), "transforms")
+for index, arg in enumerate(parsed_transforms):
+    if not isinstance(arg, (TensorOp, TensorOperation)):
+        raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg)))
return method(self, *args, **kwargs)
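
The new loop produces the indexed message asserted by the updated test at the end of this commit. A minimal sketch of the failure mode (op choice illustrative):

import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.py_transforms as py_vision

try:
    C.UniformAugment(transforms=[py_vision.Invert()], num_ops=1)
except TypeError as err:
    print(err)  # Type of Transforms[0] must be c_transform, but got <class ...>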

View File

@ -72,6 +72,26 @@ def test_pad_op():
assert mse < 0.01
+def test_pad_op2():
+    """
+    Test Pad op2
+    """
+    logger.info("test padding parameter with size 2")
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    decode_op = c_vision.Decode()
+    resize_op = c_vision.Resize([90, 90])
+    pad_op = c_vision.Pad((100, 9,))
+    ctrans = [decode_op, resize_op, pad_op]
+
+    data1 = data1.map(operations=ctrans, input_columns=["image"])
+
+    for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(data["image"].shape)
+        # It pads left, top with 100 and right, bottom with 9,
+        # so the final size of image is 90 + 100 + 9 = 199
+        assert data["image"].shape[0] == 199
+        assert data["image"].shape[1] == 199
def test_pad_grayscale():
"""

View File

@ -145,6 +145,28 @@ def test_slice_multiple_rows():
np.testing.assert_array_equal(exp_d, d['col'])
+def test_slice_none_and_ellipsis():
+    """
+    Test passing None and Ellipsis to Slice
+    """
+    dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
+    exp_dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
+
+    def gen():
+        for row in dataset:
+            yield (np.array(row),)
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    data = data.map(operations=ops.Slice(None))
+    for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset):
+        np.testing.assert_array_equal(exp_d, d['col'])
+
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+    data = data.map(operations=ops.Slice(Ellipsis))
+    for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset):
+        np.testing.assert_array_equal(exp_d, d['col'])
def test_slice_obj_neg():
slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1), [5, 4, 3, 2])
slice_compare([1, 2, 3, 4, 5], slice(-1), [1, 2, 3, 4])

View File

@ -186,10 +186,7 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Argument transforms[5] with value" \
" <mindspore.dataset.vision.py_transforms.Invert" in str(e.value)
assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,"\
" <class 'mindspore._c_dataengine.TensorOperation'>)" in str(e.value)
assert "Type of Transforms[5] must be c_transform" in str(e.value)
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):