!11853 Add call for decoupled image and text ops

From: @alexyuyue
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-02-01 19:47:00 +08:00 committed by Gitee
commit 320ea51308
14 changed files with 1249 additions and 1138 deletions


@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,10 +14,11 @@
  * limitations under the License.
  */
-#include "minddata/dataset/include/execute.h"
+#include "minddata/dataset/core/tensor_row.h"
 #ifdef ENABLE_ANDROID
 #include "minddata/dataset/include/de_tensor.h"
 #endif
+#include "minddata/dataset/include/execute.h"
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
 #ifndef ENABLE_ANDROID
@@ -84,5 +85,25 @@ std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Te
   return de_output;
 }
 
+Status Execute::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
+                           std::vector<std::shared_ptr<Tensor>> *output_tensor_list) {
+  CHECK_FAIL_RETURN_UNEXPECTED(op_ != nullptr, "Input TensorOperation is not valid");
+  CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");
+
+  TensorRow input, output;
+  std::copy(input_tensor_list.begin(), input_tensor_list.end(), std::back_inserter(input));
+  CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "Input Tensor is not valid");
+
+  std::shared_ptr<TensorOp> transform = op_->Build();
+  Status rc = transform->Compute(input, &output);
+  if (rc.IsError()) {
+    // execution failed
+    RETURN_STATUS_UNEXPECTED("Operation execution failed : " + rc.ToString());
+  }
+
+  std::copy(output.begin(), output.end(), std::back_inserter(*output_tensor_list));
+  return Status::OK();
+}
+
 }  // namespace dataset
 }  // namespace mindspore
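The overload above is the backend for eager (pipeline-free) execution of C ops on a list of tensors. As a rough Python-side sketch of what this enables, assuming the vision C ops gain the same kind of __call__ wrapper as the text ops elsewhere in this commit and that the test JPEG used by the new tests exists locally:

import numpy as np
import mindspore.dataset.vision.c_transforms as c_vision

# Run C ops eagerly on raw JPEG bytes, outside of any dataset pipeline.
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
img = c_vision.Decode()(img)      # each op is directly callable on a single input
img = c_vision.HWC2CHW()(img)     # the result comes back as a NumPy array
print(img.shape)                  # (3, 2268, 4032) for the test image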


@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,14 +28,26 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
                     auto execute = std::make_shared<Execute>(toTensorOperation(operation));
                     return execute;
                   }))
-                  .def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) {
-                    std::shared_ptr<Tensor> out = self(in);
-                    if (out == nullptr) {
-                      THROW_IF_ERROR([]() {
-                        RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
-                      }());
-                    }
-                    return out;
-                  });
+                  .def("__call__",
+                       [](Execute &self, std::shared_ptr<Tensor> in) {
+                         std::shared_ptr<Tensor> out = self(in);
+                         if (out == nullptr) {
+                           THROW_IF_ERROR([]() {
+                             RETURN_STATUS_UNEXPECTED(
+                               "Failed to execute op in eager mode, please check ERROR log above.");
+                           }());
+                         }
+                         return out;
+                       })
+                  .def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) {
+                    std::vector<std::shared_ptr<Tensor>> output_tensor_list;
+                    THROW_IF_ERROR(self(input_tensor_list, &output_tensor_list));
+                    if (output_tensor_list.empty()) {
+                      THROW_IF_ERROR([]() {
+                        RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
+                      }());
+                    }
+                    return output_tensor_list;
+                  });
                 }));
 }  // namespace dataset
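These two __call__ bindings are what the Python-side wrappers dispatch to. A minimal sketch of the direct usage pattern, mirroring what TextTensorOperation.__call__ does later in this commit; mindspore._c_dataengine is the internal module imported as cde in the Python sources, and the SlidingWindow op and input values are only illustrative:

import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset.text as text

op = text.SlidingWindow(2, 0)              # any operation that implements parse()
callable_op = cde.Execute(op.parse())      # wrap the parsed IR node in an Execute object

in_tensor = cde.Tensor(np.array([1, 2, 3, 4, 5]))
out_single = callable_op(in_tensor)        # existing Tensor -> Tensor overload
out_list = callable_op([in_tensor])        # list -> list overload added by this commit
print(out_list[0].as_array())              # sliding windows of width 2 over the input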


@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -43,16 +43,23 @@ class Execute {
 #ifdef ENABLE_ANDROID
   /// \brief callable function to execute the TensorOperation in eager mode
-  /// \param[inout] input - the tensor to be transformed
+  /// \param[in] input - the tensor to be transformed
   /// \return - the output tensor, nullptr if Compute fails
   std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input);
 #endif
 
   /// \brief callable function to execute the TensorOperation in eager mode
-  /// \param[inout] input - the tensor to be transformed
+  /// \param[in] input - the tensor to be transformed
   /// \return - the output tensor, nullptr if Compute fails
   std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input);
 
+  /// \brief callable function to execute the TensorOperation in eager mode
+  /// \param[in] input_tensor_list - the tensor to be transformed
+  /// \param[out] out - the result tensor after transform
+  /// \return - Status
+  Status operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
+                    std::vector<std::shared_ptr<Tensor>> *out);
+
  private:
   std::shared_ptr<TensorOperation> op_;
 };


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@ -59,112 +59,37 @@ from .validators import check_lookup, check_jieba_add_dict, \
check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
from ..core.datatypes import mstype_to_detype from ..core.datatypes import mstype_to_detype
from ..core.validator_helpers import replace_none from ..core.validator_helpers import replace_none
from ..transforms.c_transforms import TensorOperation
class TextTensorOperation:
def parse(self):
raise NotImplementedError("TextTensorOperation has to implement parse method.")
class Lookup(TextTensorOperation): class TextTensorOperation(TensorOperation):
""" """
Lookup operator that looks up a word to an id. Base class of Text Tensor Ops
Args:
vocab (Vocab): A vocabulary object.
unknown_token (str, optional): Word used for lookup if the word being looked up is out-of-vocabulary (OOV).
If unknown_token is OOV, a runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype that lookup maps string to (default=mstype.int32)
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Load vocabulary from list
>>> vocab = text.Vocab.from_list(['', '', '', '', ''])
>>> # Use Lookup operator to map tokens to ids
>>> lookup = text.Lookup(vocab)
>>> data1 = data1.map(operations=[lookup])
""" """
def __call__(self, input_tensor):
@check_lookup if not isinstance(input_tensor, list):
def __init__(self, vocab, unknown_token=None, data_type=mstype.int32): input_list = [input_tensor]
self.vocab = vocab else:
self.unknown_token = replace_none(unknown_token, '') input_list = input_tensor
self.data_type = data_type tensor_list = []
for tensor in input_list:
if not isinstance(tensor, str):
raise TypeError("Input should be string or list of strings, got {}.".format(type(tensor)))
tensor_list.append(cde.Tensor(np.asarray(tensor)))
callable_op = cde.Execute(self.parse())
output_list = callable_op(tensor_list)
for i, element in enumerate(output_list):
arr = element.as_array()
if arr.dtype.char == 'S':
output_list[i] = to_str(arr)
else:
output_list[i] = arr
if not isinstance(input_tensor, list) and len(output_list) == 1:
output_list = output_list[0]
return output_list
def parse(self): def parse(self):
return cde.LookupOperation(self.vocab, self.unknown_token, str(mstype_to_detype(self.data_type))) raise NotImplementedError("TextTensorOperation has to implement parse() method.")
class SlidingWindow(TextTensorOperation):
"""
TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis
is a slice of data starting at the corresponding position, with a specified width.
Args:
width (int): The width of the window. It must be an integer and greater than zero.
axis (int, optional): The axis along which the sliding window is computed (default=0).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 |
>>> # +-------------+
>>> # | [1,2,3,4,5] |
>>> # +-------------+
>>> data1 = data1.map(operations=text.SlidingWindow(3, 0))
>>> # Data after
>>> # | col1 |
>>> # +-------------+
>>> # | [[1,2,3], |
>>> # | [2,3,4], |
>>> # | [3,4,5]] |
>>> # +--------------+
"""
@check_slidingwindow
def __init__(self, width, axis=0):
self.width = width
self.axis = axis
def parse(self):
return cde.SlidingWindowOperation(self.width, self.axis)
class Ngram(TextTensorOperation):
"""
TensorOp to generate n-gram from a 1-D string Tensor.
Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an overview of what n-gram is and how it works.
Args:
n (list[int]): n in n-gram, n >= 1. n is a list of positive integers. For example, if n=[4,3], then the result
would be a 4-gram followed by a 3-gram in the same tensor. If the number of words is not enough to make up
for a n-gram, an empty string will be returned. For example, 3 grams on ["mindspore","best"] will result in
an empty string produced.
left_pad (tuple, optional): ("pad_token", pad_width). Padding performed on left side of the sequence. pad_width
will be capped at n-1. left_pad=("_",2) would pad left side of the sequence with "__" (default=None).
right_pad (tuple, optional): ("pad_token", pad_width). Padding performed on right side of the sequence.
pad_width will be capped at n-1. right_pad=("-":2) would pad right side of the sequence with "--"
(default=None).
separator (str, optional): symbol used to join strings together. For example. if 2-gram is
["mindspore", "amazing"] with separator="-", the result would be ["mindspore-amazing"]
(default=None, which means whitespace is used).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> data1 = data1.map(operations=text.Ngram(3, separator=" "))
"""
@check_ngram
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
self.ngrams = n
self.left_pad = left_pad
self.right_pad = right_pad
self.separator = separator
def parse(self):
return cde.NgramOperation(self.ngrams, self.left_pad, self.right_pad, self.separator)
DE_C_INTER_JIEBA_MODE = { DE_C_INTER_JIEBA_MODE = {
@ -174,6 +99,18 @@ DE_C_INTER_JIEBA_MODE = {
} }
DE_C_INTER_SENTENCEPIECE_LOADTYPE = {
SPieceTokenizerLoadType.FILE: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KFILE,
SPieceTokenizerLoadType.MODEL: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KMODEL
}
DE_C_INTER_SENTENCEPIECE_OUTTYPE = {
SPieceTokenizerOutType.STRING: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KString,
SPieceTokenizerOutType.INT: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KINT
}
class JiebaTokenizer(TextTensorOperation): class JiebaTokenizer(TextTensorOperation):
""" """
Tokenize Chinese string into words based on dictionary. Tokenize Chinese string into words based on dictionary.
@ -335,6 +272,201 @@ class JiebaTokenizer(TextTensorOperation):
" jieba mode file {} is not exist.".format(model_path)) " jieba mode file {} is not exist.".format(model_path))
class Lookup(TextTensorOperation):
"""
Lookup operator that looks up a word to an id.
Args:
vocab (Vocab): A vocabulary object.
unknown_token (str, optional): Word used for lookup if the word being looked up is out-of-vocabulary (OOV).
If unknown_token is OOV, a runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype that lookup maps string to (default=mstype.int32)
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Load vocabulary from list
>>> vocab = text.Vocab.from_list(['', '', '', '', ''])
>>> # Use Lookup operator to map tokens to ids
>>> lookup = text.Lookup(vocab)
>>> data1 = data1.map(operations=[lookup])
"""
@check_lookup
def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
self.vocab = vocab
self.unknown_token = replace_none(unknown_token, '')
self.data_type = data_type
def parse(self):
return cde.LookupOperation(self.vocab, self.unknown_token, str(mstype_to_detype(self.data_type)))
class Ngram(TextTensorOperation):
"""
TensorOp to generate n-gram from a 1-D string Tensor.
Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an overview of what n-gram is and how it works.
Args:
n (list[int]): n in n-gram, n >= 1. n is a list of positive integers. For example, if n=[4,3], then the result
would be a 4-gram followed by a 3-gram in the same tensor. If the number of words is not enough to make up
for a n-gram, an empty string will be returned. For example, 3 grams on ["mindspore","best"] will result in
an empty string produced.
left_pad (tuple, optional): ("pad_token", pad_width). Padding performed on left side of the sequence. pad_width
will be capped at n-1. left_pad=("_",2) would pad left side of the sequence with "__" (default=None).
right_pad (tuple, optional): ("pad_token", pad_width). Padding performed on right side of the sequence.
pad_width will be capped at n-1. right_pad=("-":2) would pad right side of the sequence with "--"
(default=None).
separator (str, optional): symbol used to join strings together. For example. if 2-gram is
["mindspore", "amazing"] with separator="-", the result would be ["mindspore-amazing"]
(default=None, which means whitespace is used).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> data1 = data1.map(operations=text.Ngram(3, separator=" "))
"""
@check_ngram
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
self.ngrams = n
self.left_pad = left_pad
self.right_pad = right_pad
self.separator = separator
def parse(self):
return cde.NgramOperation(self.ngrams, self.left_pad, self.right_pad, self.separator)
class SentencePieceTokenizer(TextTensorOperation):
"""
Tokenize scalar token or 1-D tokens to tokens by sentencepiece.
Args:
mode (Union[str, SentencePieceVocab]): If the input parameter is a file, then it is of type string.
If the input parameter is a SentencePieceVocab object, then it is of type SentencePieceVocab.
out_type (Union[str, int]): The type of output.
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
>>> tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
>>> data1 = data1.map(operations=tokenizer)
"""
def __init__(self, mode, out_type):
self.mode = mode
self.out_type = out_type
def parse(self):
return cde.SentencePieceTokenizerOperation(self.mode, DE_C_INTER_SENTENCEPIECE_OUTTYPE[self.out_type])
class SlidingWindow(TextTensorOperation):
"""
TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis
is a slice of data starting at the corresponding position, with a specified width.
Args:
width (int): The width of the window. It must be an integer and greater than zero.
axis (int, optional): The axis along which the sliding window is computed (default=0).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 |
>>> # +-------------+
>>> # | [1,2,3,4,5] |
>>> # +-------------+
>>> data1 = data1.map(operations=text.SlidingWindow(3, 0))
>>> # Data after
>>> # | col1 |
>>> # +-------------+
>>> # | [[1,2,3], |
>>> # | [2,3,4], |
>>> # | [3,4,5]] |
>>> # +--------------+
"""
@check_slidingwindow
def __init__(self, width, axis=0):
self.width = width
self.axis = axis
def parse(self):
return cde.SlidingWindowOperation(self.width, self.axis)
class ToNumber(TextTensorOperation):
"""
Tensor operation to convert every element of a string tensor to a number.
Strings are casted according to the rules specified in the following links:
https://en.cppreference.com/w/cpp/string/basic_string/stof,
https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any strings which represent negative numbers cannot be cast to an
unsigned integer type.
Args:
data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be
a numeric type.
Raises:
RuntimeError: If strings are invalid to cast, or are out of range after being casted.
Examples:
>>> import mindspore.dataset.text as text
>>> import mindspore.common.dtype as mstype
>>>
>>> to_number_op = text.ToNumber(mstype.int8)
>>> data1 = data1.map(operations=to_number_op)
"""
@check_to_number
def __init__(self, data_type):
data_type = mstype_to_detype(data_type)
self.data_type = str(data_type)
def parse(self):
return cde.ToNumberOperation(self.data_type)
class TruncateSequencePair(TextTensorOperation):
"""
Truncate a pair of rank-1 tensors such that the total length is less than max_length.
This operation takes two input tensors and returns two output Tensors.
Args:
max_length (int): Maximum length required.
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 | col2 |
>>> # +---------+---------|
>>> # | [1,2,3] | [4,5] |
>>> # +---------+---------+
>>> data1 = data1.map(operations=text.TruncateSequencePair(4))
>>> # Data after
>>> # | col1 | col2 |
>>> # +---------+---------+
>>> # | [1,2] | [4,5] |
>>> # +---------+---------+
"""
@check_pair_truncate
def __init__(self, max_length):
self.max_length = max_length
def parse(self):
return cde.TruncateSequencePairOperation(self.max_length)
class UnicodeCharTokenizer(TextTensorOperation): class UnicodeCharTokenizer(TextTensorOperation):
""" """
Tokenize a scalar tensor of UTF-8 string to Unicode characters. Tokenize a scalar tensor of UTF-8 string to Unicode characters.
@ -405,131 +537,31 @@ class WordpieceTokenizer(cde.WordpieceTokenizerOp):
self.unknown_token, self.with_offsets) self.unknown_token, self.with_offsets)
DE_C_INTER_SENTENCEPIECE_LOADTYPE = { class PythonTokenizer:
SPieceTokenizerLoadType.FILE: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KFILE,
SPieceTokenizerLoadType.MODEL: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KMODEL
}
DE_C_INTER_SENTENCEPIECE_OUTTYPE = {
SPieceTokenizerOutType.STRING: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KString,
SPieceTokenizerOutType.INT: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KINT
}
class SentencePieceTokenizer(TextTensorOperation):
""" """
Tokenize scalar token or 1-D tokens to tokens by sentencepiece. Callable class to be used for user-defined string tokenizer.
Args: Args:
mode (Union[str, SentencePieceVocab]): If the input parameter is a file, then it is of type string. tokenizer (Callable): Python function that takes a `str` and returns a list of `str` as tokens.
If the input parameter is a SentencePieceVocab object, then it is of type SentencePieceVocab.
out_type (Union[str, int]): The type of output.
Examples: Examples:
>>> import mindspore.dataset.text as text >>> import mindspore.dataset.text as text
>>> >>>
>>> vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {}) >>> def my_tokenizer(line):
>>> tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING) >>> return line.split()
>>> data1 = data1.map(operations=tokenizer) >>> data1 = data1.map(operations=text.PythonTokenizer(my_tokenizer))
""" """
def __init__(self, mode, out_type): @check_python_tokenizer
self.mode = mode def __init__(self, tokenizer):
self.out_type = out_type self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')
def parse(self):
return cde.SentencePieceTokenizerOperation(self.mode, DE_C_INTER_SENTENCEPIECE_OUTTYPE[self.out_type])
def __call__(self, in_array):
in_array = to_str(in_array)
tokens = self.tokenizer(in_array)
return tokens
if platform.system().lower() != 'windows': if platform.system().lower() != 'windows':
class WhitespaceTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string on ICU4C defined whitespaces, such as: ' ', '\\\\t', '\\\\r', '\\\\n'.
Note:
WhitespaceTokenizer is not supported on Windows platform yet.
Args:
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.WhitespaceTokenizer()
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.WhitespaceTokenizer(True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_with_offsets
def __init__(self, with_offsets=False):
self.with_offsets = with_offsets
def parse(self):
return cde.WhitespaceTokenizerOperation(self.with_offsets)
class UnicodeScriptTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.
Note:
UnicodeScriptTokenizer is not supported on Windows platform yet.
Args:
keep_whitespace (bool, optional): If or not emit whitespace tokens (default=False).
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_unicode_script_tokenizer
def __init__(self, keep_whitespace=False, with_offsets=False):
keep_whitespace = replace_none(keep_whitespace, False)
with_offsets = replace_none(with_offsets, False)
self.keep_whitespace = keep_whitespace
self.with_offsets = with_offsets
def parse(self):
return cde.UnicodeScriptTokenizerOperation(self.keep_whitespace, self.with_offsets)
class CaseFold(TextTensorOperation):
"""
Apply case fold operation on UTF-8 string tensor.
Note:
CaseFold is not supported on Windows platform yet.
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> case_op = text.CaseFold()
>>> data1 = data1.map(operations=case_op)
"""
def parse(self):
return cde.CaseFoldOperation()
DE_C_INTER_NORMALIZE_FORM = { DE_C_INTER_NORMALIZE_FORM = {
NormalizeForm.NONE: cde.NormalizeForm.DE_NORMALIZE_NONE, NormalizeForm.NONE: cde.NormalizeForm.DE_NORMALIZE_NONE,
NormalizeForm.NFC: cde.NormalizeForm.DE_NORMALIZE_NFC, NormalizeForm.NFC: cde.NormalizeForm.DE_NORMALIZE_NFC,
@ -539,118 +571,6 @@ if platform.system().lower() != 'windows':
} }
class NormalizeUTF8(TextTensorOperation):
"""
Apply normalize operation on UTF-8 string tensor.
Note:
NormalizeUTF8 is not supported on Windows platform yet.
Args:
normalize_form (NormalizeForm, optional): Valid values can be any of [NormalizeForm.NONE,
NormalizeForm.NFC, NormalizeForm.NFKC, NormalizeForm.NFD,
NormalizeForm.NFKD](default=NormalizeForm.NFKC).
See http://unicode.org/reports/tr15/ for details.
- NormalizeForm.NONE, do nothing for input string tensor.
- NormalizeForm.NFC, normalize with Normalization Form C.
- NormalizeForm.NFKC, normalize with Normalization Form KC.
- NormalizeForm.NFD, normalize with Normalization Form D.
- NormalizeForm.NFKD, normalize with Normalization Form KD.
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> normalize_op = text.NormalizeUTF8(normalize_form=NormalizeForm.NFC)
>>> data1 = data1.map(operations=normalize_op)
"""
def __init__(self, normalize_form=NormalizeForm.NFKC):
if not isinstance(normalize_form, NormalizeForm):
raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")
normalize_form = replace_none(normalize_form, NormalizeForm.NFKC)
self.normalize_form = DE_C_INTER_NORMALIZE_FORM[normalize_form]
def parse(self):
return cde.NormalizeUTF8Operation(self.normalize_form)
class RegexReplace(TextTensorOperation):
"""
Replace UTF-8 string tensor with 'replace' according to regular expression 'pattern'.
See http://userguide.icu-project.org/strings/regexp for support regex pattern.
Note:
RegexReplace is not supported on Windows platform yet.
Args:
pattern (str): the regex expression patterns.
replace (str): the string to replace matched element.
replace_all (bool, optional): If False, only replace first matched element;
if True, replace all matched elements (default=True).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> pattern = 'Canada'
>>> replace = 'China'
>>> replace_op = text.RegexReplace(pattern, replace)
>>> data1 = data1.map(operations=replace_op)
"""
def __init__(self, pattern, replace, replace_all=True):
self.pattern = pattern
self.replace = replace
self.replace_all = replace_all
def parse(self):
return cde.RegexReplaceOperation(self.pattern, self.replace, self.replace_all)
class RegexTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string by regex expression pattern.
See http://userguide.icu-project.org/strings/regexp for support regex pattern.
Note:
RegexTokenizer is not supported on Windows platform yet.
Args:
delim_pattern (str): The pattern of regex delimiters.
The original string will be split by matched elements.
keep_delim_pattern (str, optional): The string matched by 'delim_pattern' can be kept as a token
if it can be matched by 'keep_delim_pattern'. The default value is an empty str ('')
which means that delimiters will not be kept as an output token (default='').
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_regex_tokenizer
def __init__(self, delim_pattern, keep_delim_pattern='', with_offsets=False):
self.delim_pattern = delim_pattern
self.keep_delim_pattern = keep_delim_pattern
self.with_offsets = with_offsets
def parse(self):
return cde.RegexTokenizerOperation(self.delim_pattern, self.keep_delim_pattern, self.with_offsets)
class BasicTokenizer(TextTensorOperation): class BasicTokenizer(TextTensorOperation):
""" """
Tokenize a scalar tensor of UTF-8 string by specific rules. Tokenize a scalar tensor of UTF-8 string by specific rules.
@ -776,93 +696,201 @@ if platform.system().lower() != 'windows':
self.normalization_form, self.preserve_unused_token, self.with_offsets) self.normalization_form, self.preserve_unused_token, self.with_offsets)
class TruncateSequencePair(TextTensorOperation): class CaseFold(TextTensorOperation):
""" """
Truncate a pair of rank-1 tensors such that the total length is less than max_length. Apply case fold operation on UTF-8 string tensor.
This operation takes two input tensors and returns two output Tensors. Note:
CaseFold is not supported on Windows platform yet.
Args: Examples:
max_length (int): Maximum length required. >>> import mindspore.dataset.text as text
>>>
>>> case_op = text.CaseFold()
>>> data1 = data1.map(operations=case_op)
"""
Examples: def parse(self):
>>> import mindspore.dataset.text as text return cde.CaseFoldOperation()
>>>
>>> # Data before
>>> # | col1 | col2 |
>>> # +---------+---------|
>>> # | [1,2,3] | [4,5] |
>>> # +---------+---------+
>>> data1 = data1.map(operations=text.TruncateSequencePair(4))
>>> # Data after
>>> # | col1 | col2 |
>>> # +---------+---------+
>>> # | [1,2] | [4,5] |
>>> # +---------+---------+
"""
@check_pair_truncate
def __init__(self, max_length):
self.max_length = max_length
def parse(self):
return cde.TruncateSequencePairOperation(self.max_length)
class ToNumber(TextTensorOperation): class NormalizeUTF8(TextTensorOperation):
""" """
Tensor operation to convert every element of a string tensor to a number. Apply normalize operation on UTF-8 string tensor.
Strings are casted according to the rules specified in the following links: Note:
https://en.cppreference.com/w/cpp/string/basic_string/stof, NormalizeUTF8 is not supported on Windows platform yet.
https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any strings which represent negative numbers cannot be cast to an
unsigned integer type.
Args: Args:
data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be normalize_form (NormalizeForm, optional): Valid values can be any of [NormalizeForm.NONE,
a numeric type. NormalizeForm.NFC, NormalizeForm.NFKC, NormalizeForm.NFD,
NormalizeForm.NFKD](default=NormalizeForm.NFKC).
See http://unicode.org/reports/tr15/ for details.
Raises: - NormalizeForm.NONE, do nothing for input string tensor.
RuntimeError: If strings are invalid to cast, or are out of range after being casted. - NormalizeForm.NFC, normalize with Normalization Form C.
- NormalizeForm.NFKC, normalize with Normalization Form KC.
- NormalizeForm.NFD, normalize with Normalization Form D.
- NormalizeForm.NFKD, normalize with Normalization Form KD.
Examples: Examples:
>>> import mindspore.dataset.text as text >>> import mindspore.dataset.text as text
>>> import mindspore.common.dtype as mstype >>>
>>> >>> normalize_op = text.NormalizeUTF8(normalize_form=NormalizeForm.NFC)
>>> to_number_op = text.ToNumber(mstype.int8) >>> data1 = data1.map(operations=normalize_op)
>>> data1 = data1.map(operations=to_number_op) """
"""
@check_to_number def __init__(self, normalize_form=NormalizeForm.NFKC):
def __init__(self, data_type): if not isinstance(normalize_form, NormalizeForm):
data_type = mstype_to_detype(data_type) raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")
self.data_type = str(data_type)
def parse(self): normalize_form = replace_none(normalize_form, NormalizeForm.NFKC)
return cde.ToNumberOperation(self.data_type) self.normalize_form = DE_C_INTER_NORMALIZE_FORM[normalize_form]
def parse(self):
return cde.NormalizeUTF8Operation(self.normalize_form)
class PythonTokenizer: class RegexReplace(TextTensorOperation):
""" """
Callable class to be used for user-defined string tokenizer. Replace UTF-8 string tensor with 'replace' according to regular expression 'pattern'.
Args: See http://userguide.icu-project.org/strings/regexp for support regex pattern.
tokenizer (Callable): Python function that takes a `str` and returns a list of `str` as tokens.
Examples: Note:
>>> import mindspore.dataset.text as text RegexReplace is not supported on Windows platform yet.
>>>
>>> def my_tokenizer(line):
>>> return line.split()
>>> data1 = data1.map(operations=text.PythonTokenizer(my_tokenizer))
"""
@check_python_tokenizer Args:
def __init__(self, tokenizer): pattern (str): the regex expression patterns.
self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)') replace (str): the string to replace matched element.
replace_all (bool, optional): If False, only replace first matched element;
if True, replace all matched elements (default=True).
def __call__(self, in_array): Examples:
in_array = to_str(in_array) >>> import mindspore.dataset.text as text
tokens = self.tokenizer(in_array) >>>
return tokens >>> pattern = 'Canada'
>>> replace = 'China'
>>> replace_op = text.RegexReplace(pattern, replace)
>>> data1 = data1.map(operations=replace_op)
"""
def __init__(self, pattern, replace, replace_all=True):
self.pattern = pattern
self.replace = replace
self.replace_all = replace_all
def parse(self):
return cde.RegexReplaceOperation(self.pattern, self.replace, self.replace_all)
class RegexTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string by regex expression pattern.
See http://userguide.icu-project.org/strings/regexp for support regex pattern.
Note:
RegexTokenizer is not supported on Windows platform yet.
Args:
delim_pattern (str): The pattern of regex delimiters.
The original string will be split by matched elements.
keep_delim_pattern (str, optional): The string matched by 'delim_pattern' can be kept as a token
if it can be matched by 'keep_delim_pattern'. The default value is an empty str ('')
which means that delimiters will not be kept as an output token (default='').
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_regex_tokenizer
def __init__(self, delim_pattern, keep_delim_pattern='', with_offsets=False):
self.delim_pattern = delim_pattern
self.keep_delim_pattern = keep_delim_pattern
self.with_offsets = with_offsets
def parse(self):
return cde.RegexTokenizerOperation(self.delim_pattern, self.keep_delim_pattern, self.with_offsets)
class UnicodeScriptTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.
Note:
UnicodeScriptTokenizer is not supported on Windows platform yet.
Args:
keep_whitespace (bool, optional): If or not emit whitespace tokens (default=False).
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_unicode_script_tokenizer
def __init__(self, keep_whitespace=False, with_offsets=False):
keep_whitespace = replace_none(keep_whitespace, False)
with_offsets = replace_none(with_offsets, False)
self.keep_whitespace = keep_whitespace
self.with_offsets = with_offsets
def parse(self):
return cde.UnicodeScriptTokenizerOperation(self.keep_whitespace, self.with_offsets)
class WhitespaceTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string on ICU4C defined whitespaces, such as: ' ', '\\\\t', '\\\\r', '\\\\n'.
Note:
WhitespaceTokenizer is not supported on Windows platform yet.
Args:
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.WhitespaceTokenizer()
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.WhitespaceTokenizer(True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_with_offsets
def __init__(self, with_offsets=False):
self.with_offsets = with_offsets
def parse(self):
return cde.WhitespaceTokenizerOperation(self.with_offsets)
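With __call__ implemented on TextTensorOperation, the text ops above can be invoked directly on strings without building a dataset pipeline, which is what the new test cases later in this commit rely on. A short sketch; the vocabulary contents and input strings are illustrative:

import mindspore.dataset.text as text

vocab = text.Vocab.from_list(["home", "is", "behind", "the", "world"])
lookup = text.Lookup(vocab)
print(lookup("behind"))                  # eager lookup of a single word -> id 2

tokenizer = text.UnicodeCharTokenizer()
print(tokenizer("abc"))                  # eager tokenization -> the single characters of the input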


@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,6 +26,14 @@ from .validators import check_num_classes, check_de_type, check_fill_value, chec
 from ..core.datatypes import mstype_to_detype
 
 
+class TensorOperation:
+    def __call__(self):
+        raise NotImplementedError("TensorOperation has to implement __call__() method.")
+
+    def parse(self):
+        raise NotImplementedError("TensorOperation has to implement parse() method.")
+
+
 class OneHot(cde.OneHotOp):
     """
     Tensor operation to apply one hot encoding.
@@ -304,7 +312,7 @@ class Unique(cde.UniqueOp):
     Also return an index tensor that contains the index of each element of the
     input tensor in the Unique output tensor.
 
-    Finally, return a count tensor that constains the count of each element of
+    Finally, return a count tensor that contains the count of each element of
     the output tensor in the input tensor.
 
     Note:


@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -237,8 +237,8 @@ def check_compose_list(method):
         type_check(transforms, (list,), transforms)
         if not transforms:
             raise ValueError("transforms list is empty.")
-        for i, transfrom in enumerate(transforms):
-            if not callable(transfrom):
+        for i, transform in enumerate(transforms):
+            if not callable(transform):
                 raise ValueError("transforms[{}] is not callable.".format(i))
 
         return method(self, *args, **kwargs)
@@ -269,9 +269,10 @@ def check_random_apply(method):
         [transforms, prob], _ = parse_user_args(method, *args, **kwargs)
         type_check(transforms, (list,), "transforms")
 
-        for i, transfrom in enumerate(transforms):
-            if not callable(transfrom):
-                raise ValueError("transforms[{}] is not callable.".format(i))
+        for i, transform in enumerate(transforms):
+            if str(transform).find("c_transform") >= 0:
+                raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
+                                 .format(i))
 
         if prob is not None:
             type_check(prob, (float, int,), "prob")
@@ -290,9 +291,10 @@ def check_transforms_list(method):
         [transforms], _ = parse_user_args(method, *args, **kwargs)
 
         type_check(transforms, (list,), "transforms")
-        for i, transfrom in enumerate(transforms):
-            if not callable(transfrom):
-                raise ValueError("transforms[{}] is not callable.".format(i))
+        for i, transform in enumerate(transforms):
+            if str(transform).find("c_transform") >= 0:
+                raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
+                                 .format(i))
 
         return method(self, *args, **kwargs)
 
     return new_method
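The reworded check above means that putting a C transform inside the Python-side random wrappers now fails with an explicit message instead of the generic "not callable" error, which is what the updated compose test asserts. A small sketch of the expected behaviour:

import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.vision.c_transforms as c_vision

try:
    # A C transform inside a py_transforms wrapper is rejected when the wrapper is constructed.
    py_transforms.RandomApply([c_vision.RandomResizedCrop(200)])
except ValueError as err:
    print(err)   # transforms[0] is not a py transforms. Should not use a c transform in py transform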

File diff suppressed because it is too large.


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,6 +29,20 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 
+def test_HWC2CHW_callable():
+    """
+    Test HWC2CHW is callable
+    """
+    logger.info("Test HWC2CHW callable")
+    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
+    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
+
+    img = c_vision.Decode()(img)
+    img = c_vision.HWC2CHW()(img)
+    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
+    assert img.shape == (3, 2268, 4032)
+
+
 def test_HWC2CHW(plot=False):
     """
     Test HWC2CHW
@@ -122,6 +136,7 @@ def test_HWC2CHW_comp(plot=False):
 
 if __name__ == '__main__':
+    test_HWC2CHW_callable()
     test_HWC2CHW(True)
     test_HWC2CHW_md5()
     test_HWC2CHW_comp(True)


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -219,7 +219,7 @@ def test_c_py_compose_vision_module(plot=False, run_golden=True):
 
 def test_py_transforms_with_c_vision():
     """
-    These examples will fail, as py_transforms.Random(Apply/Choice/Order) expect callable functions
+    These examples will fail, as c_transform should not be used in py_transforms.Random(Apply/Choice/Order)
     """
 
     ds.config.set_seed(0)
@@ -236,15 +236,15 @@ def test_py_transforms_with_c_vision():
 
     with pytest.raises(ValueError) as error_info:
         test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
-    assert "transforms[0] is not callable." in str(error_info.value)
+    assert "transforms[0] is not a py transforms." in str(error_info.value)
 
     with pytest.raises(ValueError) as error_info:
         test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
-    assert "transforms[0] is not callable." in str(error_info.value)
+    assert "transforms[0] is not a py transforms." in str(error_info.value)
 
     with pytest.raises(ValueError) as error_info:
         test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
-    assert "transforms[1] is not callable." in str(error_info.value)
+    assert "transforms[1] is not a py transforms." in str(error_info.value)
 
     with pytest.raises(RuntimeError) as error_info:
         test_config([py_transforms.OneHotOp(20, 0.1)])


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,6 +29,21 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
 
 GENERATE_GOLDEN = False
 
 
+def test_invert_callable():
+    """
+    Test Invert is callable
+    """
+    logger.info("Test Invert callable")
+    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
+    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
+
+    img = C.Decode()(img)
+    img = C.Invert()(img)
+    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
+    assert img.shape == (2268, 4032, 3)
+
+
 def test_invert_py(plot=False):
     """
     Test Invert python op
@@ -247,6 +262,7 @@ def test_invert_md5_c():
 
 if __name__ == "__main__":
+    test_invert_callable()
     test_invert_py(plot=False)
     test_invert_c(plot=False)
     test_invert_py_c(plot=False)


@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -34,6 +34,22 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 GENERATE_GOLDEN = False
 
 
+def test_random_crop_and_resize_callable():
+    """
+    Test RandomCropAndResize op is callable
+    """
+    logger.info("test_random_crop_and_resize_callable")
+    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
+    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
+
+    decode_op = c_vision.Decode()
+    img = decode_op(img)
+
+    random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
+    img = random_crop_and_resize_op(img)
+    assert np.shape(img) == (256, 512, 3)
+
+
 def test_random_crop_and_resize_op_c(plot=False):
     """
     Test RandomCropAndResize op in c transforms
@@ -389,6 +405,7 @@ def test_random_crop_and_resize_06():
 
 if __name__ == "__main__":
+    test_random_crop_and_resize_callable()
     test_random_crop_and_resize_op_c(True)
     test_random_crop_and_resize_op_py(True)
     test_random_crop_and_resize_op_py_ANTIALIAS()


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ import numpy as np
 import mindspore.dataset as ds
 from mindspore.dataset.text import JiebaTokenizer
 from mindspore.dataset.text import JiebaMode, to_str
+from mindspore import log as logger
 
 DATA_FILE = "../data/dataset/testJiebaDataset/3.txt"
 DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*"
@@ -24,6 +25,23 @@ HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8"
 MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8"
 
 
+def test_jieba_callable():
+    """
+    Test jieba tokenizer op is callable
+    """
+    logger.info("test_jieba_callable")
+    jieba_op1 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
+    jieba_op2 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM)
+
+    text1 = "今天天气太好了我们一起去外面玩吧"
+    text2 = "男默女泪市长江大桥"
+    assert np.array_equal(jieba_op1(text1), ['今天天气', '太好了', '我们', '一起', '', '外面', '玩吧'])
+    assert np.array_equal(jieba_op2(text1), ['今天', '天气', '', '', '', '我们', '一起', '', '外面', '', ''])
+
+    jieba_op1.add_word("男默女泪")
+    assert np.array_equal(jieba_op1(text2), ['男默女泪', '', '长江大桥'])
+
+
 def test_jieba_1():
     """Test jieba tokenizer with MP mode"""
     data = ds.TextFileDataset(DATA_FILE)
@@ -457,6 +475,7 @@ def test_jieba_6():
 
 if __name__ == "__main__":
+    test_jieba_callable()
     test_jieba_1()
     test_jieba_1_1()
     test_jieba_1_2()


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,6 +28,24 @@ from util import visualize_list, diff_mse
 
 DATA_DIR = "../data/dataset/testImageNetData/train/"
 
 
+def test_uniform_augment_callable(num_ops=2):
+    """
+    Test UniformAugment is callable
+    """
+    logger.info("test_uniform_augment_callable")
+    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
+    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
+
+    decode_op = C.Decode()
+    img = decode_op(img)
+
+    transforms_ua = [C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
+                     C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
+    uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
+    img = uni_aug([img, img])
+    assert ((np.shape(img) == (2, 2268, 4032, 3)) or (np.shape(img) == (1, 400, 400, 3)))
+
+
 def test_uniform_augment(plot=False, num_ops=2):
     """
     Test UniformAugment
@@ -262,6 +280,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
 
 if __name__ == "__main__":
+    test_uniform_augment_callable(num_ops=2)
     test_uniform_augment(num_ops=1, plot=True)
     test_cpp_uniform_augment(num_ops=1, plot=True)
     test_cpp_uniform_augment_exception_pyops(num_ops=1)


@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.text as text
 import mindspore.common.dtype as mstype
+from mindspore import log as logger
 
 # this file contains "home is behind the world head" each word is 1 line
 DATA_FILE = "../data/dataset/testVocab/words.txt"
@@ -25,6 +26,16 @@ VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
 SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"
 
 
+def test_lookup_callable():
+    """
+    Test lookup is callable
+    """
+    logger.info("test_lookup_callable")
+    vocab = text.Vocab.from_list(['', '', '', '', ''])
+    lookup = text.Lookup(vocab)
+    word = ""
+    assert lookup(word) == 3
+
+
 def test_from_list_tutorial():
     vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
     lookup = text.Lookup(vocab, "<unk>")
@@ -171,6 +182,7 @@ def test_lookup_cast_type():
 
 if __name__ == '__main__':
+    test_lookup_callable()
     test_from_dict_exception()
     test_from_list_tutorial()
     test_from_file_tutorial()