Fix CI warning and test team issues

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
This commit is contained in:
alex-yuyue 2021-03-10 10:36:53 -05:00
parent 9add98350e
commit 2457f52596
8 changed files with 47 additions and 19 deletions

View File

@ -492,6 +492,5 @@ PYBIND_REGISTER(
return uniform_aug;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@ -25,7 +25,6 @@
namespace mindspore {
namespace dataset {
#ifdef ENABLE_ICU4C
PYBIND_REGISTER(
@ -262,6 +261,5 @@ PYBIND_REGISTER(SPieceTokenizerOutType, 0, ([](const py::module *m) {
.value("DE_SPIECE_TOKENIZER_OUTTYPE_KINT", SPieceTokenizerOutType::kInt)
.export_values();
}));
} // namespace dataset
} // namespace mindspore

View File

@ -605,7 +605,7 @@ class SubsetSampler(BuiltinSampler):
Samples the elements from a sequence of indices.
Args:
indices (list[int]): A sequence of indices.
indices (Any iterable python object but string): A sequence of indices.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
Examples:
@ -633,6 +633,13 @@ class SubsetSampler(BuiltinSampler):
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
if not isinstance(indices, str) and validator.is_iterable(indices):
indices = _get_sample_ids_as_list(indices, num_samples)
elif isinstance(indices, int):
@ -645,13 +652,6 @@ class SubsetSampler(BuiltinSampler):
raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
self.indices = indices
super().__init__(num_samples)

View File

@ -31,7 +31,13 @@ class TensorOperation:
Base class Tensor Ops
"""
def __call__(self, *input_tensor_list):
tensor_row = [cde.Tensor(np.asarray(tensor)) for tensor in input_tensor_list]
tensor_row = []
for tensor in input_tensor_list:
try:
tensor_row.append(cde.Tensor(np.asarray(tensor)))
except RuntimeError:
raise TypeError("Invalid user input. Got {}: {}, cannot be converted into tensor." \
.format(type(tensor), tensor))
callable_op = cde.Execute(self.parse())
output_tensor_list = callable_op(tensor_row)
for i, element in enumerate(output_tensor_list):

View File

@ -1197,12 +1197,13 @@ class RandomSharpness(ImageTensorOperation):
class RandomSolarize(ImageTensorOperation):
"""
Invert all pixel values with given range.
Randomly invert the pixel values of input image within given range.
Args:
threshold (tuple, optional): Range of random solarize threshold. Threshold values should always be
in the range (0, 255), include at least one integer value in the given range and be in
(min, max) format. If min=max, then invert all pixel values above min(max) (default=(0, 255)).
threshold (tuple, optional): Range of random solarize threshold (default=(0, 255)).
Threshold values should always be in (min, max) format,
where min <= max, min and max are integers in the range (0, 255).
If min=max, then invert all pixel values above min(max).
Examples:
>>> transforms_list = [c_vision.Decode(), c_vision.RandomSolarize(threshold=(10,100))]

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import copy
import numpy as np
import mindspore.dataset.text as text
import mindspore.dataset as ds
from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
@ -21,6 +22,13 @@ VOCAB_FILE = "../data/dataset/test_sentencepiece/botchan.txt"
DATA_FILE = "../data/dataset/testTokenizerData/sentencepiece_tokenizer.txt"
def test_sentence_piece_tokenizer_callable():
vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
data = '123'
assert np.array_equal(tokenizer(data), ['', '12', '3'])
def test_from_vocab_to_str_UNIGRAM():
vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
@ -160,6 +168,7 @@ def test_with_zip_concat():
if __name__ == "__main__":
test_sentence_piece_tokenizer_callable()
test_from_vocab_to_str_UNIGRAM()
test_from_vocab_to_str_BPE()
test_from_vocab_to_str_CHAR()

View File

@ -16,6 +16,7 @@
Testing BertTokenizer op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.text as text
@ -127,7 +128,7 @@ test_paras = [
preserve_unused_token=True,
vocab_list=vocab_bert
),
# test non-default parms
# test non-default params
dict(
first=8,
last=8,
@ -242,6 +243,19 @@ def test_bert_tokenizer_with_offsets():
check_bert_tokenizer_with_offsets(**paras)
def test_bert_tokenizer_callable_invalid_input():
"""
Test WordpieceTokenizer in eager mode with invalid input
"""
data = {'张三': 18, '王五': 20}
vocab = text.Vocab.from_list(vocab_bert)
tokenizer_op = text.BertTokenizer(vocab=vocab)
with pytest.raises(TypeError) as info:
_ = tokenizer_op(data)
assert "Invalid user input. Got <class 'dict'>: {'张三': 18, '王五': 20}, cannot be converted into tensor." in str(info)
if __name__ == '__main__':
test_bert_tokenizer_callable_invalid_input()
test_bert_tokenizer_default()
test_bert_tokenizer_with_offsets()

View File

@ -52,9 +52,10 @@ def test_to_number_eager():
# test input invalid tensor
invalid_input = [["1", "2", "3"], ["4", "5"]]
with pytest.raises(RuntimeError) as info:
with pytest.raises(TypeError) as info:
_ = op(invalid_input)
assert "Invalid data type." in str(info.value)
assert "Invalid user input. Got <class 'list'>: [['1', '2', '3'], ['4', '5']], cannot be converted into tensor" in \
str(info.value)
def test_to_number_typical_case_integral():