fix bug for type check

xulei2020 2021-03-31 15:05:51 +08:00
parent 5c521492da
commit 354da691d5
3 changed files with 24 additions and 4 deletions


@@ -51,7 +51,8 @@ from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPi
 from .validators import check_lookup, check_jieba_add_dict, \
     check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
     check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
-    check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
+    check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow, \
+    check_sentence_piece_tokenizer
 from ..core.datatypes import mstype_to_detype
 from ..core.validator_helpers import replace_none
 from ..transforms.c_transforms import TensorOperation
@@ -325,7 +326,7 @@ class SentencePieceTokenizer(TextTensorOperation):
     Args:
         mode (Union[str, SentencePieceVocab]): If the input parameter is a file, then it is of type string.
             If the input parameter is a SentencePieceVocab object, then it is of type SentencePieceVocab.
-        out_type (Union[str, int]): The type of output.
+        out_type (SPieceTokenizerOutType): The type of the output, which can be int or string.
 
     Examples:
         >>> from mindspore.dataset.text import SentencePieceModel, SPieceTokenizerOutType
@@ -335,7 +336,7 @@ class SentencePieceTokenizer(TextTensorOperation):
         >>> tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
         >>> text_file_dataset = text_file_dataset.map(operations=tokenizer)
     """
-
+    @check_sentence_piece_tokenizer
     def __init__(self, mode, out_type):
         self.mode = mode
         self.out_type = out_type
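With the decorator applied, invalid argument types are reported as soon as SentencePieceTokenizer is constructed. A short sketch of the expected behaviour (not part of the commit; the vocab setup mirrors the docstring example above, and "corpus.txt" is a placeholder file name):

```python
import mindspore.dataset.text as text
from mindspore.dataset.text import SentencePieceModel, SPieceTokenizerOutType

# Build a vocab as in the docstring example; "corpus.txt" is a placeholder path.
vocab = text.SentencePieceVocab.from_file(["corpus.txt"], 5000, 0.9995,
                                           SentencePieceModel.UNIGRAM, {})

# Valid call: mode is a SentencePieceVocab, out_type is an SPieceTokenizerOutType member.
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)

# Invalid call: a plain string for out_type is now rejected up front with a TypeError
# raised by check_sentence_piece_tokenizer, instead of failing later in the backend.
try:
    text.SentencePieceTokenizer(vocab, out_type="string")
except TypeError as err:
    print(err)
```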


@@ -515,3 +515,19 @@ def check_save_model(method):
         return method(self, *args, **kwargs)
 
     return new_method
+
+
+def check_sentence_piece_tokenizer(method):
+    """A wrapper that wraps a parameter checker to the original function."""
+    from .utils import SPieceTokenizerOutType
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [mode, out_type], _ = parse_user_args(method, *args, **kwargs)
+
+        type_check(mode, (str, cde.SentencePieceVocab), "mode is not an instance of str or cde.SentencePieceVocab.")
+        type_check(out_type, (SPieceTokenizerOutType,), "out_type is not an instance of SPieceTokenizerOutType")
+
+        return method(self, *args, **kwargs)
+
+    return new_method
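The new validator follows the same pattern as the other wrappers in this file: parse the user arguments out of the call, type-check them, and only then invoke the original method. A simplified, self-contained sketch of that pattern (stub enum and tokenizer class, not MindSpore code):

```python
from enum import Enum
from functools import wraps


class OutTypeStub(Enum):
    """Stand-in for SPieceTokenizerOutType."""
    STRING = 0
    INT = 1


def check_tokenizer_args(method):
    """Reject bad argument types before the wrapped __init__ runs."""
    @wraps(method)
    def new_method(self, mode, out_type):
        if not isinstance(mode, str):
            raise TypeError("mode is not an instance of str.")
        if not isinstance(out_type, OutTypeStub):
            raise TypeError("out_type is not an instance of SPieceTokenizerOutType.")
        return method(self, mode, out_type)
    return new_method


class TokenizerStub:
    @check_tokenizer_args
    def __init__(self, mode, out_type):
        self.mode = mode
        self.out_type = out_type


TokenizerStub("m.model", OutTypeStub.STRING)    # passes the type check
try:
    TokenizerStub("m.model", "string")          # wrong out_type: TypeError
except TypeError as err:
    print(err)
```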


@@ -31,7 +31,8 @@ def check_crop_size(size):
     if isinstance(size, int):
         check_value(size, (1, FLOAT_MAX_INTEGER))
     elif isinstance(size, (tuple, list)) and len(size) == 2:
-        for value in size:
+        for index, value in enumerate(size):
+            type_check(value, (int,), "size[{}]".format(index))
             check_value(value, (1, FLOAT_MAX_INTEGER))
     else:
         raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
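For illustration (not part of the commit), a self-contained sketch of what the per-element check adds; the type_check helper below only mimics the behaviour of the real one in validator_helpers:

```python
def type_check(value, valid_types, arg_name):
    """Simplified stand-in for mindspore.dataset.core.validator_helpers.type_check."""
    if not isinstance(value, valid_types):
        raise TypeError("Argument {0} with value {1} is not of type {2}."
                        .format(arg_name, value, list(valid_types)))


def check_crop_size(size):
    """Accept an int or a (h, w) pair of ints, as in the patched validator."""
    if isinstance(size, int):
        return
    if isinstance(size, (tuple, list)) and len(size) == 2:
        for index, value in enumerate(size):
            # The element index lands in the error message, e.g. "size[1]".
            type_check(value, (int,), "size[{}]".format(index))
        return
    raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")


check_crop_size((224, 224))        # both elements are ints: passes
try:
    check_crop_size((224, 224.0))  # float element is now caught with a clear message
except TypeError as err:
    print(err)
```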
@@ -93,6 +94,8 @@ def check_normalize_c_param(mean, std):
 def check_normalize_py_param(mean, std):
     type_check(mean, (list, tuple), "mean")
     type_check(std, (list, tuple), "std")
+    if len(mean) != len(std):
+        raise ValueError("Length of mean and std must be equal.")
     for mean_value in mean: