!1365 Clean up work for text python sub-package

Merge pull request !1365 from h.farahat/text_namespace
This commit is contained in:
mindspore-ci-bot 2020-05-22 23:16:31 +08:00 committed by Gitee
commit 458436186c
33 changed files with 273 additions and 246 deletions

View File

@ -52,7 +52,7 @@ add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(nlp)
add_subdirectory(text)
######################################################################
################### Create _c_dataengine Library ######################
@ -62,7 +62,6 @@ set(submodules
$<TARGET_OBJECTS:kernels>
$<TARGET_OBJECTS:kernels-image>
$<TARGET_OBJECTS:kernels-data>
$<TARGET_OBJECTS:kernels-text>
$<TARGET_OBJECTS:APItoPython>
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
@ -70,8 +69,8 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:nlp>
$<TARGET_OBJECTS:nlp-kernels>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
)
if (ENABLE_TDTQUE)

View File

@ -38,10 +38,6 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include "dataset/nlp/vocab.h"
#include "dataset/nlp/kernels/lookup_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
@ -63,6 +59,10 @@
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "dataset/text/vocab.h"
#include "dataset/text/kernels/lookup_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_pk_sample.h"
@ -577,9 +577,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("TEXTFILE", OpName::kTextFile);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_INTER_JIEBA_MIX", JiebaMode::kMix)
.value("DE_INTER_JIEBA_MP", JiebaMode::kMp)
.value("DE_INTER_JIEBA_HMM", JiebaMode::kHmm)
.value("DE_JIEBA_MIX", JiebaMode::kMix)
.value("DE_JIEBA_MP", JiebaMode::kMp)
.value("DE_JIEBA_HMM", JiebaMode::kHmm)
.export_values();
(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())

View File

@ -2,7 +2,6 @@ add_subdirectory(image)
add_subdirectory(data)
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(text)
add_library(kernels OBJECT
py_func_op.cc
tensor_op.cc)

View File

@ -1,7 +0,0 @@
add_subdirectory(kernels)
add_library(nlp OBJECT
vocab.cc
)
add_dependencies(nlp nlp-kernels)

View File

@ -1,3 +0,0 @@
add_library(nlp-kernels OBJECT
lookup_op.cc
)

View File

@ -0,0 +1,7 @@
add_subdirectory(kernels)
add_library(text OBJECT
vocab.cc
)
add_dependencies(text text-kernels)

View File

@ -1,6 +1,7 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(kernels-text OBJECT
jieba_tokenizer_op.cc
unicode_char_tokenizer_op.cc
)
add_library(text-kernels OBJECT
lookup_op.cc
jieba_tokenizer_op.cc
unicode_char_tokenizer_op.cc
)

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
#include <vector>
#include <memory>

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/nlp/kernels/lookup_op.h"
#include "dataset/text/kernels/lookup_op.h"
#include <string>

View File

@ -24,7 +24,7 @@
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
#include "dataset/nlp/vocab.h"
#include "dataset/text/vocab.h"
namespace mindspore {
namespace dataset {

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
#include <memory>
#include <string>
#include <string_view>

View File

@ -17,7 +17,7 @@
#include <map>
#include <utility>
#include "dataset/nlp/vocab.h"
#include "dataset/text/vocab.h"
namespace mindspore {
namespace dataset {

View File

@ -284,10 +284,10 @@ class Dataset:
Examples:
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.transforms.text.utils as text
>>> import mindspore.dataset.text as text
>>> # declare a function which returns a Dataset object
>>> def flat_map_func(x):
>>> data_dir = text.as_text(x[0])
>>> data_dir = text.to_str(x[0])
>>> d = ds.ImageFolderDatasetV2(data_dir)
>>> return d
>>> # data is a Dataset object

View File

@ -15,5 +15,5 @@
"""
mindspore.dataset.text
"""
from .c_transforms import *
from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer
from .utils import to_str, to_bytes, JiebaMode, Vocab

View File

@ -11,20 +11,40 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module c_transforms provides common nlp operations.
c transforms for all text related operators
"""
import os
import re
import mindspore._c_dataengine as cde
from .utils import JiebaMode
from .validators import check_jieba_add_dict, check_jieba_add_word, check_jieba_init
from .validators import check_lookup, check_jieba_add_dict, \
check_jieba_add_word, check_jieba_init
class Lookup(cde.LookupOp):
"""
Lookup operator that looks up a word to an id
Args:
vocab(Vocab): a Vocab object
unknown(None,int): default id to lookup a word that is out of vocab
"""
@check_lookup
def __init__(self, vocab, unknown=None):
if unknown is None:
super().__init__(vocab)
else:
super().__init__(vocab, unknown)
DE_C_INTER_JIEBA_MODE = {
JiebaMode.MIX: cde.JiebaMode.DE_INTER_JIEBA_MIX,
JiebaMode.MP: cde.JiebaMode.DE_INTER_JIEBA_MP,
JiebaMode.HMM: cde.JiebaMode.DE_INTER_JIEBA_HMM
JiebaMode.MIX: cde.JiebaMode.DE_JIEBA_MIX,
JiebaMode.MP: cde.JiebaMode.DE_JIEBA_MP,
JiebaMode.HMM: cde.JiebaMode.DE_JIEBA_HMM
}
@ -41,6 +61,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
"HMM" mode will tokenize with Hiddel Markov Model Segment algorithm,
"MIX" model will tokenize with a mix of MPSegment and HMMSegment algorithm.
"""
@check_jieba_init
def __init__(self, hmm_path, mp_path, mode=JiebaMode.MIX):
self.mode = mode

View File

@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
c transforms for all text related operators
Some basic function for nlp
"""
from enum import IntEnum
import mindspore._c_dataengine as cde
from .validators import check_lookup, check_from_list, check_from_dict, check_from_file
import numpy as np
from .validators import check_from_file, check_from_list, check_from_dict
class Vocab(cde.Vocab):
@ -61,17 +64,43 @@ class Vocab(cde.Vocab):
return super().from_dict(word_dict)
class Lookup(cde.LookupOp):
def to_str(array, encoding='utf8'):
"""
Lookup operator that looks up a word to an id
Convert numpy array of `bytes` to array of `str` by decoding each element based on charset `encoding`.
Args:
vocab(Vocab): a Vocab object
unknown(None,int): default id to lookup a word that is out of vocab
array (numpy array): Array of type `bytes` representing strings.
encoding (string): Indicating the charset for decoding.
Returns:
Numpy array of `str`.
"""
@check_lookup
def __init__(self, vocab, unknown=None):
if unknown is None:
super().__init__(vocab)
else:
super().__init__(vocab, unknown)
if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array')
return np.char.decode(array, encoding)
def to_bytes(array, encoding='utf8'):
"""
Convert numpy array of `str` to array of `bytes` by encoding each element based on charset `encoding`.
Args:
array (numpy array): Array of type `str` representing strings.
encoding (string): Indicating the charset for encoding.
Returns:
Numpy array of `bytes`.
"""
if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array')
return np.char.encode(array, encoding)
class JiebaMode(IntEnum):
MIX = 0
MP = 1
HMM = 2

View File

@ -17,8 +17,11 @@ validators for text ops
"""
from functools import wraps
import mindspore._c_dataengine as cde
from ..transforms.validators import check_uint32
def check_lookup(method):
"""A wrapper that wrap a parameter checker to the original function(crop operation)."""
@ -106,3 +109,67 @@ def check_from_dict(method):
return method(self, **kwargs)
return new_method
def check_jieba_init(method):
"""Wrapper method to check the parameters of jieba add word."""
@wraps(method)
def new_method(self, *args, **kwargs):
hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
if "hmm_path" in kwargs:
hmm_path = kwargs.get("hmm_path")
if "mp_path" in kwargs:
mp_path = kwargs.get("mp_path")
if hmm_path is None:
raise ValueError(
"the dict of HMMSegment in cppjieba is not provided")
kwargs["hmm_path"] = hmm_path
if mp_path is None:
raise ValueError(
"the dict of MPSegment in cppjieba is not provided")
kwargs["mp_path"] = mp_path
if model is not None:
kwargs["model"] = model
return method(self, **kwargs)
return new_method
def check_jieba_add_word(method):
"""Wrapper method to check the parameters of jieba add word."""
@wraps(method)
def new_method(self, *args, **kwargs):
word, freq = (list(args) + 2 * [None])[:2]
if "word" in kwargs:
word = kwargs.get("word")
if "freq" in kwargs:
freq = kwargs.get("freq")
if word is None:
raise ValueError("word is not provided")
kwargs["word"] = word
if freq is not None:
check_uint32(freq)
kwargs["freq"] = freq
return method(self, **kwargs)
return new_method
def check_jieba_add_dict(method):
"""Wrapper method to check the parameters of add dict"""
@wraps(method)
def new_method(self, *args, **kwargs):
user_dict = (list(args) + [None])[0]
if "user_dict" in kwargs:
user_dict = kwargs.get("user_dict")
if user_dict is None:
raise ValueError("user_dict is not provided")
kwargs["user_dict"] = user_dict
return method(self, **kwargs)
return new_method

View File

@ -1,21 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is to support nlp augmentations. It includes two parts:
c_transforms and py_transforms. C_transforms is a high performance
image augmentation module which is developed with c++ opencv. Py_transforms
provide more kinds of image augmentations which is developed with python PIL.
"""
from .utils import as_text, JiebaMode
from . import c_transforms

View File

@ -1,43 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some basic function for nlp
"""
from enum import IntEnum
import numpy as np
def as_text(array, encoding='utf8'):
"""
Convert data of array to unicode.
Args:
array (numpy array): Data of array should be ASCII values of each character after converted.
encoding (string): Indicating the charset for decoding.
Returns:
A 'str' object.
"""
if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array')
decode = np.vectorize(lambda x: x.decode(encoding))
return decode(array)
class JiebaMode(IntEnum):
MIX = 0
MP = 1
HMM = 2

View File

@ -1,79 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for TensorOps.
"""
from functools import wraps
from ...transforms.validators import check_uint32
def check_jieba_init(method):
"""Wrapper method to check the parameters of jieba add word."""
@wraps(method)
def new_method(self, *args, **kwargs):
hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
if "hmm_path" in kwargs:
hmm_path = kwargs.get("hmm_path")
if "mp_path" in kwargs:
mp_path = kwargs.get("mp_path")
if hmm_path is None:
raise ValueError(
"the dict of HMMSegment in cppjieba is not provided")
kwargs["hmm_path"] = hmm_path
if mp_path is None:
raise ValueError(
"the dict of MPSegment in cppjieba is not provided")
kwargs["mp_path"] = mp_path
if model is not None:
kwargs["model"] = model
return method(self, **kwargs)
return new_method
def check_jieba_add_word(method):
"""Wrapper method to check the parameters of jieba add word."""
@wraps(method)
def new_method(self, *args, **kwargs):
word, freq = (list(args) + 2 * [None])[:2]
if "word" in kwargs:
word = kwargs.get("word")
if "freq" in kwargs:
freq = kwargs.get("freq")
if word is None:
raise ValueError("word is not provided")
kwargs["word"] = word
if freq is not None:
check_uint32(freq)
kwargs["freq"] = freq
return method(self, **kwargs)
return new_method
def check_jieba_add_dict(method):
"""Wrapper method to check the parameters of add dict"""
@wraps(method)
def new_method(self, *args, **kwargs):
user_dict = (list(args) + [None])[0]
if "user_dict" in kwargs:
user_dict = kwargs.get("user_dict")
if user_dict is None:
raise ValueError("user_dict is not provided")
kwargs["user_dict"] = user_dict
return method(self, **kwargs)
return new_method

View File

@ -18,7 +18,7 @@
#include <string_view>
#include "common/common.h"
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

View File

@ -18,7 +18,7 @@
#include <string_view>
#include "common/common.h"
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

View File

@ -0,0 +1,2 @@
home is behind the world ahead
is behind home ahead world the

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.text.utils as nlp
from mindspore import log as logger
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"

View File

@ -24,7 +24,6 @@ def test_flat_map_1():
'''
DATA_FILE records the path of image folders, load the images from them.
'''
import mindspore.dataset.transforms.text.utils as nlp
def flat_map_func(x):
data_dir = x[0].item().decode('utf8')
@ -45,7 +44,6 @@ def test_flat_map_2():
'''
Flatten 3D structure data
'''
import mindspore.dataset.transforms.text.utils as nlp
def flat_map_func_1(x):
data_dir = x[0].item().decode('utf8')

View File

@ -27,7 +27,7 @@ import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter
from mindspore.dataset.transforms.text import as_text
from mindspore.dataset.text import to_str
from mindspore.mindrecord import FileWriter
FILES_NUM = 4
@ -73,7 +73,7 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(as_text(item["file_name"])))
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -93,7 +93,7 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
logger.info("-------------- item[data]: \
{}------------------------".format(item["data"][:10]))
logger.info("-------------- item[file_name]: \
{}------------------------".format(as_text(item["file_name"])))
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -111,7 +111,7 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(as_text(item["file_name"])))
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
@ -128,7 +128,7 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(as_text(item["file_name"])))
{}------------------------".format(to_str(item["file_name"])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1

View File

@ -0,0 +1,46 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
# this file contains "home is behind the world head" each word is 1 line
DATA_FILE = "../data/dataset/testVocab/words.txt"
VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8"
MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8"
def test_on_tokenized_line():
data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False)
jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP)
with open(VOCAB_FILE, 'r') as f:
for line in f:
word = line.split(',')[0]
jieba_op.add_word(word)
data = data.map(input_columns=["text"], operations=jieba_op)
vocab = text.Vocab.from_file(VOCAB_FILE, ",")
lookup = text.Lookup(vocab)
data = data.map(input_columns=["text"], operations=lookup)
res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
[11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32)
for i, d in enumerate(data.create_dict_iterator()):
np.testing.assert_array_equal(d["text"], res[i]), i
if __name__ == '__main__':
test_on_tokenized_line()

View File

@ -14,8 +14,8 @@
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.transforms.text.c_transforms import JiebaTokenizer
from mindspore.dataset.transforms.text.utils import JiebaMode, as_text
from mindspore.dataset.text import JiebaTokenizer
from mindspore.dataset.text import JiebaMode, to_str
DATA_FILE = "../data/dataset/testJiebaDataset/3.txt"
DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*"
@ -33,7 +33,7 @@ def test_jieba_1():
expect = ['今天天气', '太好了', '我们', '一起', '', '外面', '玩吧']
ret = []
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -46,7 +46,7 @@ def test_jieba_1_1():
operations=jieba_op, num_parallel_workers=1)
expect = ['今天', '天气', '', '', '', '我们', '一起', '', '外面', '', '']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -59,7 +59,7 @@ def test_jieba_1_2():
operations=jieba_op, num_parallel_workers=1)
expect = ['今天天气', '太好了', '我们', '一起', '', '外面', '玩吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -74,7 +74,7 @@ def test_jieba_2():
data = data.map(input_columns=["text"],
operations=jieba_op, num_parallel_workers=2)
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -89,7 +89,7 @@ def test_jieba_2_1():
operations=jieba_op, num_parallel_workers=2)
expect = ['男默女泪', '', '长江大桥']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -113,7 +113,7 @@ def test_jieba_2_3():
operations=jieba_op, num_parallel_workers=2)
expect = ['江州', '市长', '江大桥', '参加', '', '长江大桥', '', '通车', '仪式']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -131,7 +131,7 @@ def test_jieba_3():
operations=jieba_op, num_parallel_workers=1)
expect = ['男默女泪', '', '长江大桥']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -150,7 +150,7 @@ def test_jieba_3_1():
operations=jieba_op, num_parallel_workers=1)
expect = ['男默女泪', '市长', '江大桥']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -166,7 +166,7 @@ def test_jieba_4():
operations=jieba_op, num_parallel_workers=1)
expect = ['今天天气', '太好了', '我们', '一起', '', '外面', '玩吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -192,7 +192,7 @@ def test_jieba_5():
operations=jieba_op, num_parallel_workers=1)
expect = ['江州', '市长', '江大桥', '参加', '', '长江大桥', '', '通车', '仪式']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]
@ -203,7 +203,7 @@ def gen():
def pytoken_op(input_data):
te = str(as_text(input_data))
te = str(to_str(input_data))
tokens = []
tokens.append(te[:5].encode("UTF8"))
tokens.append(te[5:10].encode("UTF8"))
@ -217,7 +217,7 @@ def test_jieba_6():
operations=pytoken_op, num_parallel_workers=1)
expect = ['今天天气太', '好了我们一', '起去外面玩吧']
for i in data.create_dict_iterator():
ret = as_text(i["text"])
ret = to_str(i["text"])
for index, item in enumerate(ret):
assert item == expect[index]

View File

@ -16,6 +16,8 @@ import mindspore._c_dataengine as cde
import numpy as np
import pytest
from mindspore.dataset.text import to_str, to_bytes
import mindspore.dataset as ds
import mindspore.common.dtype as mstype
@ -65,7 +67,8 @@ def test_map():
data = ds.GeneratorDataset(gen, column_names=["col"])
def split(b):
splits = b.item().decode("utf8").split()
s = to_str(b)
splits = s.item().split()
return np.array(splits, dtype='S')
data = data.map(input_columns=["col"], operations=split)
@ -74,11 +77,20 @@ def test_map():
np.testing.assert_array_equal(d[0], expected)
def as_str(arr):
def decode(s): return s.decode("utf8")
def test_map2():
def gen():
yield np.array(["ab cde 121"], dtype='S'),
decode_v = np.vectorize(decode)
return decode_v(arr)
data = ds.GeneratorDataset(gen, column_names=["col"])
def upper(b):
out = np.char.upper(b)
return out
data = data.map(input_columns=["col"], operations=upper)
expected = np.array(["AB CDE 121"], dtype='S')
for d in data:
np.testing.assert_array_equal(d[0], expected)
line = np.array(["This is a text file.",
@ -106,9 +118,9 @@ def test_tfrecord1():
assert d["line"].shape == line[i].shape
assert d["words"].shape == words[i].shape
assert d["chinese"].shape == chinese[i].shape
np.testing.assert_array_equal(line[i], as_str(d["line"]))
np.testing.assert_array_equal(words[i], as_str(d["words"]))
np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
np.testing.assert_array_equal(line[i], to_str(d["line"]))
np.testing.assert_array_equal(words[i], to_str(d["words"]))
np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
def test_tfrecord2():
@ -118,9 +130,9 @@ def test_tfrecord2():
assert d["line"].shape == line[i].shape
assert d["words"].shape == words[i].shape
assert d["chinese"].shape == chinese[i].shape
np.testing.assert_array_equal(line[i], as_str(d["line"]))
np.testing.assert_array_equal(words[i], as_str(d["words"]))
np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
np.testing.assert_array_equal(line[i], to_str(d["line"]))
np.testing.assert_array_equal(words[i], to_str(d["words"]))
np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
def test_tfrecord3():
@ -135,9 +147,9 @@ def test_tfrecord3():
assert d["line"].shape == line[i].shape
assert d["words"].shape == words[i].reshape([2, 2]).shape
assert d["chinese"].shape == chinese[i].shape
np.testing.assert_array_equal(line[i], as_str(d["line"]))
np.testing.assert_array_equal(words[i].reshape([2, 2]), as_str(d["words"]))
np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
np.testing.assert_array_equal(line[i], to_str(d["line"]))
np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"]))
np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
def create_text_mindrecord():
@ -167,16 +179,17 @@ def test_mindrecord():
for i, d in enumerate(data.create_dict_iterator()):
assert d["english"].shape == line[i].shape
assert d["chinese"].shape == chinese[i].shape
np.testing.assert_array_equal(line[i], as_str(d["english"]))
np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
np.testing.assert_array_equal(line[i], to_str(d["english"]))
np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
if __name__ == '__main__':
# test_generator()
# test_basic()
# test_batching_strings()
test_generator()
test_basic()
test_batching_strings()
test_map()
# test_tfrecord1()
# test_tfrecord2()
# test_tfrecord3()
# test_mindrecord()
test_map2()
test_tfrecord1()
test_tfrecord2()
test_tfrecord3()
test_mindrecord()

View File

@ -17,8 +17,7 @@ Testing UnicodeCharTokenizer op in DE
"""
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.transforms.text.c_transforms as nlp
import mindspore.dataset.transforms.text.utils as nlp_util
import mindspore.dataset.text as nlp
DATA_FILE = "../data/dataset/testTokenizerData/1.txt"
@ -43,7 +42,7 @@ def test_unicode_char_tokenizer():
dataset = dataset.map(operations=tokenizer)
tokens = []
for i in dataset.create_dict_iterator():
text = nlp_util.as_text(i['text']).tolist()
text = nlp.to_str(i['text']).tolist()
tokens.append(text)
logger.info("The out tokens is : {}".format(tokens))
assert split_by_unicode_char(input_strs) == tokens