diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc index 2825950f253..ecf91c36e1f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc @@ -141,6 +141,11 @@ Status BuildVocabOp::CollectorThread() { } } int64_t num_words = std::min(static_cast(words.size()), top_k_); + if (num_words == 0) { + MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second + << ") vocab would be empty (except for special tokens)."; + } + // this would take the top-k most frequent words std::partial_sort(words.begin(), words.begin() + num_words, words.end(), [this](const std::string &w1, const std::string &w2) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc index d0973d91db6..5cdfa8bb767 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc @@ -323,9 +323,7 @@ Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) { // if mRecursive == false, don't go into folder of folders Status ImageFolderOp::RecursiveWalkFolder(Path *dir) { std::shared_ptr dir_itr = Path::DirIterator::OpenDirectory(dir); - if (dir_itr == nullptr) { - RETURN_STATUS_UNEXPECTED("Error encountered when indexing files"); - } + RETURN_UNEXPECTED_IF_NULL(dir_itr); while (dir_itr->hasNext()) { Path subdir = dir_itr->next(); if (subdir.IsDirectory()) { diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index cc3686eb002..674848f156c 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -32,6 +32,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema", "Schema", - "DistributedSampler", "PKSampler", - "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] + "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", + "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 91785d15c17..b2b08472f62 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -42,8 +42,8 @@ from .iterators import DictIterator, TupleIterator from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ check_rename, check_numpyslicesdataset, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ - check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\ - check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\ + check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ + check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ check_split, check_bucket_batch_by_length, check_cluedataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist @@ -895,26 +895,7 @@ class Dataset: return ProjectDataset(self, columns) def build_vocab(self, vocab, columns, freq_range, top_k): - """ - Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab - which contains top_k most frequent words (if top_k is specified) - This function is not meant to be called directly by user. To build vocab, please use the function - text.Vocab.from_dataset() - - Args: - vocab(Vocab): vocab object - columns(str or list, optional): column names to get words from. It can be a list of column names. - (Default is None where all columns will be used. If any column isn't string type, will return error) - freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency - range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency - can be None, which corresponds to 0/total_words separately (default is None, all words are included) - top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are - taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default is None - all words are included) - - Returns: - BuildVocabDataset - """ + """ Internal function for building a vocab""" return BuildVocabDataset(self, vocab, columns, freq_range, top_k) def apply(self, apply_func): @@ -1468,6 +1449,7 @@ class DatasetOp(Dataset): # No need for __init__ since it is the same as the super's init + class BucketBatchByLengthDataset(DatasetOp): """ The result of applying BucketBatchByLength operator to the input dataset. @@ -1608,7 +1590,7 @@ class BatchDataset(DatasetOp): Args: dataset (Dataset): dataset to be checked. - batchsize (int): batch size to notify. + batch_size (int): batch size to notify. """ if isinstance(dataset, SyncWaitDataset): dataset.update_sync_batch_size(batch_size) @@ -1646,7 +1628,7 @@ class BlockReleasePair: Args: init_release_rows (int): Number of lines to allow through the pipeline. - callback (function): The callback funciton that will be called when release is called. + callback (function): The callback function that will be called when release is called. """ def __init__(self, init_release_rows, callback=None): @@ -1710,7 +1692,7 @@ class SyncWaitDataset(DatasetOp): input_dataset (Dataset): Input dataset to apply flow control. num_batch (int): the number of batches without blocking at the start of each epoch. condition_name (str): The condition name that is used to toggle sending next row. - callback (function): The callback funciton that will be invoked when sync_update is called. + callback (function): The callback function that will be invoked when sync_update is called. Raises: RuntimeError: If condition name already exists. @@ -2066,7 +2048,7 @@ class SkipDataset(DatasetOp): The result of applying Skip operator to the input Dataset. Args: - datasets (tuple): A tuple of datasets to be skipped. + input_dataset (tuple): A tuple of datasets to be skipped. count (int): Number of rows the dataset should be skipped. """ @@ -3055,7 +3037,7 @@ class GeneratorDataset(MappableDataset): provide either column_names or schema. column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None). If provided, sanity check will be performed on generator output. - schema (Schema/String, optional): Path to the json schema file or schema object (default=None). Users are + schema (Schema/str, optional): Path to the json schema file or schema object (default=None). Users are required to provide either column_names or schema. If both are provided, schema will be used. num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images). @@ -4343,7 +4325,7 @@ class CelebADataset(MappableDataset): dataset_dir (str): Path to the root directory that contains the dataset. num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config). shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None). - dataset_type (string): one of 'all', 'train', 'valid' or 'test'. + dataset_type (str): one of 'all', 'train', 'valid' or 'test'. sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). decode (bool, optional): decode the images after reading (default=False). extensions (list[str], optional): List of file extensions to be @@ -4874,18 +4856,15 @@ class BuildVocabDataset(DatasetOp): text.Vocab.from_dataset() Args: - vocab(Vocab): vocab object. + vocab(Vocab): text.vocab object. columns(str or list, optional): column names to get words from. It can be a list of column names (Default is - None, all columns are used, return error if any column isn't string). + None, all columns are used, return error if any column isn't string). freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency - range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency - can be None, which corresponds to 0/total_words separately (default is None, all words are included). + range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency + can be None, which corresponds to 0/total_words separately (default=None, all words are included). top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are - taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default is None - all words are included). - - Returns: - BuildVocabDataset + taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default=None, + all words are included). """ def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None): diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 4a64ed3c424..7984a560603 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -30,8 +30,8 @@ class Lookup(cde.LookupOp): """ Lookup operator that looks up a word to an id Args: - vocab(Vocab): a Vocab object - unknown(None,int): default id to lookup a word that is out of vocab + vocab(Vocab): a Vocab object. + unknown(int): default id to lookup a word that is out of vocab (default is None). """ @check_lookup @@ -45,16 +45,21 @@ class Lookup(cde.LookupOp): class Ngram(cde.NgramOp): """ TensorOp to generate n-gram from a 1-D string Tensor - Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an explanation of what n-gram is. + Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an overview of what n-gram is and how it works. + Args: - n(int or list): n in n-gram, n >= 1. n is a list of positive integers, for e.g. n=[4,3], The result - would be a 4-gram followed by a 3-gram in the same tensor. - left_pad(tuple, optional): ("pad_token",pad_width). Padding performed on left side of the sequence. pad_width - will be capped at n-1. left_pad=("_",2) would pad left side of the sequence with "__". (Default is None) - right_pad(tuple, optional): ("pad_token",pad_width). Padding performed on right side of the sequence. pad_width - will be capped at n-1. right_pad=("-":2) would pad right side of the sequence with "--". (Default is None) + n([int, list]): n in n-gram, n >= 1. n is a list of positive integers, for e.g. n=[4,3], The result + would be a 4-gram followed by a 3-gram in the same tensor. If number of words is not enough to make up for + a n-gram, an empty string would be returned. For e.g. 3 grams on ["mindspore","best"] would result in an + empty string be produced. + left_pad(tuple, optional): ("pad_token", pad_width). Padding performed on left side of the sequence. pad_width + will be capped at n-1. left_pad=("_",2) would pad left side of the sequence with "__" (Default is None). + right_pad(tuple, optional): ("pad_token", pad_width). Padding performed on right side of the sequence. + pad_width will be capped at n-1. right_pad=("-":2) would pad right side of the sequence with "--" + (Default is None). separator(str,optional): symbol used to join strings together. for e.g. if 2-gram the ["mindspore", "amazing"] - with separator="-" the result would be ["mindspore-amazing"]. (Default is None which means whitespace is used) + with separator="-" the result would be ["mindspore-amazing"] (Default is None which means whitespace is + used). """ @check_ngram diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py index 87be604cbbf..819417623a7 100644 --- a/mindspore/dataset/text/utils.py +++ b/mindspore/dataset/text/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Some basic function for nlp +Some basic function for text """ from enum import IntEnum @@ -25,42 +25,47 @@ from .validators import check_from_file, check_from_list, check_from_dict, check class Vocab(cde.Vocab): """ - Vocab object that is used for lookup word + Vocab object that is used for lookup word. """ @classmethod @check_from_dataset def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None): """ - Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab - which contains top_k most frequent words (if top_k is specified) + Build a vocab from a dataset. This would collect all unique words in a dataset and return a vocab within + the frequency range specified by user in freq_range. User would be warned if no words fall into the frequency. + Words in vocab are ordered from highest frequency to lowest frequency. Words with the same frequency would be + ordered lexicographically. + Args: dataset(Dataset): dataset to build vocab from. - columns(str or list, optional): column names to get words from. It can be a list of column names. - (Default is None where all columns will be used. If any column isn't string type, will return error) + columns([str, list], optional): column names to get words from. It can be a list of column names. + (Default=None where all columns will be used. If any column isn't string type, will return error) freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency - range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency - can be None, which corresponds to 0/total_words separately (default is None, all words are included) + range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency=0 is the same as + min_frequency=1. max_frequency > total_words is the same as max_frequency = total_words. + min_frequency/max_frequency can be None, which corresponds to 0/total_words separately + (default=None, all words are included). top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are - taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default is None - all words are included) + taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default=None + all words are included). return: - text.Vocab: vocab object built from dataset. + text.Vocab: Vocab object built from dataset. """ vocab = Vocab() root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k) for d in root.create_dict_iterator(): if d is not None: - raise ValueError("from_dataset should receive data other than None") + raise ValueError("from_dataset should receive data other than None.") return vocab @classmethod @check_from_list def from_list(cls, word_list): """ - build a vocab object from a list of word + build a vocab object from a list of word. Args: - word_list(list): a list of string where each element is a word + word_list(list): a list of string where each element is a word. """ return super().from_list(word_list) @@ -68,11 +73,12 @@ class Vocab(cde.Vocab): @check_from_file def from_file(cls, file_path, delimiter=None, vocab_size=None): """ - build a vocab object from a list of word + build a vocab object from a list of word. Args: - file_path(str): path to the file which contains the vocab list - delimiter(None, str): a delimiter to break up each line in file, the first element is taken to be the word - vocab_size(None, int): number of words to read from file_path + file_path(str): path to the file which contains the vocab list. + delimiter(str, optional): a delimiter to break up each line in file, the first element is taken to be + the word (default=None). + vocab_size(int, optional): number of words to read from file_path (default=None, all words are taken). """ return super().from_file(file_path, delimiter, vocab_size) @@ -82,7 +88,7 @@ class Vocab(cde.Vocab): """ build a vocab object from a dict. Args: - word_dict(dict): dict contains word, id pairs. id should start from 2 and continuous + word_dict(dict): dict contains word, id pairs. id should start from 2 and be continuous. """ return super().from_dict(word_dict) @@ -100,7 +106,7 @@ def to_str(array, encoding='utf8'): """ if not isinstance(array, np.ndarray): - raise ValueError('input should be a numpy array') + raise ValueError('input should be a numpy array.') return np.char.decode(array, encoding) @@ -118,7 +124,7 @@ def to_bytes(array, encoding='utf8'): """ if not isinstance(array, np.ndarray): - raise ValueError('input should be a numpy array') + raise ValueError('input should be a numpy array.') return np.char.encode(array, encoding) diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index 4004bf40a43..9c39175aa6a 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -24,7 +24,7 @@ from ..transforms.validators import check_uint32, check_pos_int64 def check_lookup(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wrap a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -35,10 +35,10 @@ def check_lookup(method): unknown = kwargs.get("unknown") if unknown is not None: if not (isinstance(unknown, int) and unknown >= 0): - raise ValueError("unknown needs to be a non-negative integer") + raise ValueError("unknown needs to be a non-negative integer.") if not isinstance(vocab, cde.Vocab): - raise ValueError("vocab is not an instance of cde.Vocab") + raise ValueError("vocab is not an instance of cde.Vocab.") kwargs["vocab"] = vocab kwargs["unknown"] = unknown @@ -48,7 +48,7 @@ def check_lookup(method): def check_from_file(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wrap a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -61,16 +61,16 @@ def check_from_file(method): vocab_size = kwargs.get("vocab_size") if not isinstance(file_path, str): - raise ValueError("file_path needs to be str") + raise ValueError("file_path needs to be str.") if delimiter is not None: if not isinstance(delimiter, str): - raise ValueError("delimiter needs to be str") + raise ValueError("delimiter needs to be str.") else: delimiter = "" if vocab_size is not None: if not (isinstance(vocab_size, int) and vocab_size > 0): - raise ValueError("vocab size needs to be a positive integer") + raise ValueError("vocab size needs to be a positive integer.") else: vocab_size = -1 kwargs["file_path"] = file_path @@ -82,7 +82,7 @@ def check_from_file(method): def check_from_list(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wrap a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -90,10 +90,10 @@ def check_from_list(method): if "word_list" in kwargs: word_list = kwargs.get("word_list") if not isinstance(word_list, list): - raise ValueError("word_list needs to be a list of words") + raise ValueError("word_list needs to be a list of words.") for word in word_list: if not isinstance(word, str): - raise ValueError("each word in word list needs to be type str") + raise ValueError("each word in word list needs to be type str.") kwargs["word_list"] = word_list return method(self, **kwargs) @@ -102,7 +102,7 @@ def check_from_list(method): def check_from_dict(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wrap a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -110,12 +110,12 @@ def check_from_dict(method): if "word_dict" in kwargs: word_dict = kwargs.get("word_dict") if not isinstance(word_dict, dict): - raise ValueError("word_dict needs to be a list of word,id pairs") + raise ValueError("word_dict needs to be a list of word,id pairs.") for word, word_id in word_dict.items(): if not isinstance(word, str): - raise ValueError("each word in word_dict needs to be type str") + raise ValueError("each word in word_dict needs to be type str.") if not (isinstance(word_id, int) and word_id >= 0): - raise ValueError("each word id needs to be positive integer") + raise ValueError("each word id needs to be positive integer.") kwargs["word_dict"] = word_dict return method(self, **kwargs) @@ -135,11 +135,11 @@ def check_jieba_init(method): mp_path = kwargs.get("mp_path") if hmm_path is None: raise ValueError( - "the dict of HMMSegment in cppjieba is not provided") + "the dict of HMMSegment in cppjieba is not provided.") kwargs["hmm_path"] = hmm_path if mp_path is None: raise ValueError( - "the dict of MPSegment in cppjieba is not provided") + "the dict of MPSegment in cppjieba is not provided.") kwargs["mp_path"] = mp_path if model is not None: kwargs["model"] = model @@ -160,7 +160,7 @@ def check_jieba_add_word(method): if "freq" in kwargs: freq = kwargs.get("freq") if word is None: - raise ValueError("word is not provided") + raise ValueError("word is not provided.") kwargs["word"] = word if freq is not None: check_uint32(freq) @@ -179,7 +179,7 @@ def check_jieba_add_dict(method): if "user_dict" in kwargs: user_dict = kwargs.get("user_dict") if user_dict is None: - raise ValueError("user_dict is not provided") + raise ValueError("user_dict is not provided.") kwargs["user_dict"] = user_dict return method(self, **kwargs) @@ -187,7 +187,7 @@ def check_jieba_add_dict(method): def check_from_dataset(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wrap a parameter checker to the original function.""" # def from_dataset(cls, dataset, columns, freq_range=None, top_k=None): @wraps(method) @@ -210,27 +210,27 @@ def check_from_dataset(method): for column in columns: if not isinstance(column, str): - raise ValueError("columns need to be a list of strings") + raise ValueError("columns need to be a list of strings.") if freq_range is None: freq_range = (None, None) if not isinstance(freq_range, tuple) or len(freq_range) != 2: - raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None") + raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") for num in freq_range: if num is not None and (not isinstance(num, int)): - raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None") + raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): if freq_range[0] > freq_range[1] or freq_range[0] < 0: - raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") + raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") if top_k is not None and (not isinstance(top_k, int)): - raise ValueError("top_k needs to be a positive integer") + raise ValueError("top_k needs to be a positive integer.") if isinstance(top_k, int) and top_k <= 0: - raise ValueError("top_k needs to be a positive integer") + raise ValueError("top_k needs to be a positive integer.") kwargs["dataset"] = dataset kwargs["columns"] = columns @@ -243,7 +243,7 @@ def check_from_dataset(method): def check_ngram(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wrap a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -261,11 +261,11 @@ def check_ngram(method): n = [n] if not (isinstance(n, list) and n != []): - raise ValueError("n needs to be a non-empty list of positive integers") + raise ValueError("n needs to be a non-empty list of positive integers.") for gram in n: if not (isinstance(gram, int) and gram > 0): - raise ValueError("n in ngram needs to be a positive number\n") + raise ValueError("n in ngram needs to be a positive number.") if left_pad is None: left_pad = ("", 0) @@ -275,20 +275,20 @@ def check_ngram(method): if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( left_pad[1], int)): - raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width") + raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.") if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str) and isinstance( right_pad[1], int)): - raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width") + raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.") if not (left_pad[1] >= 0 and right_pad[1] >= 0): - raise ValueError("padding width need to be positive numbers") + raise ValueError("padding width need to be positive numbers.") if separator is None: separator = " " if not isinstance(separator, str): - raise ValueError("separator needs to be a string") + raise ValueError("separator needs to be a string.") kwargs["n"] = n kwargs["left_pad"] = left_pad diff --git a/tests/ut/python/dataset/test_ngram_op.py b/tests/ut/python/dataset/test_ngram_op.py index f2da1fb863b..73b2702378e 100644 --- a/tests/ut/python/dataset/test_ngram_op.py +++ b/tests/ut/python/dataset/test_ngram_op.py @@ -16,7 +16,7 @@ Testing Ngram in mindspore.dataset """ import mindspore.dataset as ds -import mindspore.dataset.text as nlp +import mindspore.dataset.text as text import numpy as np @@ -39,7 +39,7 @@ def test_multiple_ngrams(): yield (np.array(line.split(" "), dtype='S'),) dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=nlp.Ngram([1, 2, 3], ("_", 2), ("_", 2), " ")) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram([1, 2, 3], ("_", 2), ("_", 2), " ")) i = 0 for data in dataset.create_dict_iterator(): @@ -61,7 +61,7 @@ def test_simple_ngram(): yield (np.array(line.split(" "), dtype='S'),) dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=nlp.Ngram(3, separator=None)) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None)) i = 0 for data in dataset.create_dict_iterator(): @@ -73,11 +73,11 @@ def test_corner_cases(): """ testing various corner cases and exceptions""" def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None): - def gen(text): - yield (np.array(text.split(" "), dtype='S'),) + def gen(texts): + yield (np.array(texts.split(" "), dtype='S'),) dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=nlp.Ngram(n, l_pad, r_pad, separator=sep)) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep)) for data in dataset.create_dict_iterator(): assert [d.decode("utf8") for d in data["text"]] == output_line, output_line