forked from mindspore-Ecosystem/mindspore

Commit 6c21e556c4 (parent df361d1d26): Clean up work for text python package
@@ -52,7 +52,7 @@ add_subdirectory(core)
 add_subdirectory(kernels)
 add_subdirectory(engine)
 add_subdirectory(api)
-add_subdirectory(nlp)
+add_subdirectory(text)
 ######################################################################

 ################### Create _c_dataengine Library ######################
@@ -62,15 +62,14 @@ set(submodules
         $<TARGET_OBJECTS:kernels>
         $<TARGET_OBJECTS:kernels-image>
         $<TARGET_OBJECTS:kernels-data>
-        $<TARGET_OBJECTS:kernels-text>
         $<TARGET_OBJECTS:APItoPython>
         $<TARGET_OBJECTS:engine-datasetops-source>
         $<TARGET_OBJECTS:engine-datasetops-source-sampler>
         $<TARGET_OBJECTS:engine-datasetops>
         $<TARGET_OBJECTS:engine-opt>
         $<TARGET_OBJECTS:engine>
-        $<TARGET_OBJECTS:nlp>
-        $<TARGET_OBJECTS:nlp-kernels>
+        $<TARGET_OBJECTS:text>
+        $<TARGET_OBJECTS:text-kernels>
 )

 if (ENABLE_TDTQUE)
@@ -38,10 +38,6 @@
 #include "dataset/kernels/image/resize_op.h"
 #include "dataset/kernels/image/uniform_aug_op.h"
 #include "dataset/kernels/data/type_cast_op.h"
-#include "dataset/kernels/text/jieba_tokenizer_op.h"
-#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
-#include "dataset/nlp/vocab.h"
-#include "dataset/nlp/kernels/lookup_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"
@@ -62,6 +58,10 @@
 #include "dataset/engine/datasetops/source/text_file_op.h"
 #include "dataset/engine/datasetops/source/voc_op.h"
 #include "dataset/kernels/data/to_float16_op.h"
+#include "dataset/text/kernels/jieba_tokenizer_op.h"
+#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
+#include "dataset/text/vocab.h"
+#include "dataset/text/kernels/lookup_op.h"
 #include "dataset/util/random.h"
 #include "mindrecord/include/shard_operator.h"
 #include "mindrecord/include/shard_pk_sample.h"
@@ -549,9 +549,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
     .value("TEXTFILE", OpName::kTextFile);

   (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
-    .value("DE_INTER_JIEBA_MIX", JiebaMode::kMix)
-    .value("DE_INTER_JIEBA_MP", JiebaMode::kMp)
-    .value("DE_INTER_JIEBA_HMM", JiebaMode::kHmm)
+    .value("DE_JIEBA_MIX", JiebaMode::kMix)
+    .value("DE_JIEBA_MP", JiebaMode::kMp)
+    .value("DE_JIEBA_HMM", JiebaMode::kHmm)
     .export_values();

   (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
@@ -2,7 +2,6 @@ add_subdirectory(image)
 add_subdirectory(data)
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_subdirectory(text)
 add_library(kernels OBJECT
     py_func_op.cc
     tensor_op.cc)
@@ -1,7 +0,0 @@
-add_subdirectory(kernels)
-
-add_library(nlp OBJECT
-    vocab.cc
-)
-
-add_dependencies(nlp nlp-kernels)
@@ -1,3 +0,0 @@
-add_library(nlp-kernels OBJECT
-    lookup_op.cc
-)
@@ -0,0 +1,7 @@
+add_subdirectory(kernels)
+
+add_library(text OBJECT
+    vocab.cc
+)
+
+add_dependencies(text text-kernels)
@@ -1,6 +1,7 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(kernels-text OBJECT
-    jieba_tokenizer_op.cc
-    unicode_char_tokenizer_op.cc
-)
+add_library(text-kernels OBJECT
+    lookup_op.cc
+    jieba_tokenizer_op.cc
+    unicode_char_tokenizer_op.cc
+)
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "dataset/kernels/text/jieba_tokenizer_op.h"
+#include "dataset/text/kernels/jieba_tokenizer_op.h"

 #include <vector>
 #include <memory>
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "dataset/nlp/kernels/lookup_op.h"
+#include "dataset/text/kernels/lookup_op.h"

 #include <string>

@@ -24,7 +24,7 @@
 #include "dataset/core/tensor.h"
 #include "dataset/kernels/tensor_op.h"
 #include "dataset/util/status.h"
-#include "dataset/nlp/vocab.h"
+#include "dataset/text/vocab.h"

 namespace mindspore {
 namespace dataset {
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
 #include <memory>
 #include <string>
 #include <string_view>
@@ -17,7 +17,7 @@
 #include <map>
 #include <utility>

-#include "dataset/nlp/vocab.h"
+#include "dataset/text/vocab.h"

 namespace mindspore {
 namespace dataset {
@@ -284,10 +284,10 @@ class Dataset:

         Examples:
             >>> import mindspore.dataset as ds
-            >>> import mindspore.dataset.transforms.text.utils as text
+            >>> import mindspore.dataset.text as text
             >>> # declare a function which returns a Dataset object
             >>> def flat_map_func(x):
-            >>>     data_dir = text.as_text(x[0])
+            >>>     data_dir = text.to_str(x[0])
             >>>     d = ds.ImageFolderDatasetV2(data_dir)
             >>>     return d
             >>> # data is a Dataset object
@@ -15,5 +15,5 @@
 """
 mindspore.dataset.text
 """
-from .c_transforms import *
+from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer
+from .utils import to_str, to_bytes, JiebaMode, Vocab
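Note: with this hunk the public text API moves from mindspore.dataset.transforms.text to mindspore.dataset.text. A minimal import sketch of the new surface (not part of the commit; Vocab.from_dict is taken from the utils.py hunk below, and the no-argument UnicodeCharTokenizer constructor is an assumption):

    import mindspore.dataset.text as text

    vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5})
    lookup = text.Lookup(vocab)                 # word -> id operator
    tokenizer = text.UnicodeCharTokenizer()     # assumed to take no arguments
    print(text.JiebaMode.MP)                    # IntEnum re-exported from utils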
@ -11,20 +11,40 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
|
||||||
"""
|
"""
|
||||||
This module c_transforms provides common nlp operations.
|
c transforms for all text related operators
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import mindspore._c_dataengine as cde
|
import mindspore._c_dataengine as cde
|
||||||
|
|
||||||
from .utils import JiebaMode
|
from .utils import JiebaMode
|
||||||
from .validators import check_jieba_add_dict, check_jieba_add_word, check_jieba_init
|
from .validators import check_lookup, check_jieba_add_dict, \
|
||||||
|
check_jieba_add_word, check_jieba_init
|
||||||
|
|
||||||
|
|
||||||
|
class Lookup(cde.LookupOp):
|
||||||
|
"""
|
||||||
|
Lookup operator that looks up a word to an id
|
||||||
|
Args:
|
||||||
|
vocab(Vocab): a Vocab object
|
||||||
|
unknown(None,int): default id to lookup a word that is out of vocab
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_lookup
|
||||||
|
def __init__(self, vocab, unknown=None):
|
||||||
|
if unknown is None:
|
||||||
|
super().__init__(vocab)
|
||||||
|
else:
|
||||||
|
super().__init__(vocab, unknown)
|
||||||
|
|
||||||
|
|
||||||
DE_C_INTER_JIEBA_MODE = {
|
DE_C_INTER_JIEBA_MODE = {
|
||||||
JiebaMode.MIX: cde.JiebaMode.DE_INTER_JIEBA_MIX,
|
JiebaMode.MIX: cde.JiebaMode.DE_JIEBA_MIX,
|
||||||
JiebaMode.MP: cde.JiebaMode.DE_INTER_JIEBA_MP,
|
JiebaMode.MP: cde.JiebaMode.DE_JIEBA_MP,
|
||||||
JiebaMode.HMM: cde.JiebaMode.DE_INTER_JIEBA_HMM
|
JiebaMode.HMM: cde.JiebaMode.DE_JIEBA_HMM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
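Note: a short usage sketch for the Lookup operator added above, wired into a dataset pipeline (the file names are placeholders, not part of the commit; Vocab.from_file with a comma delimiter and the map call mirror the new test later in this diff):

    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    vocab = text.Vocab.from_file("vocab_list.txt", ",")   # comma-separated vocabulary file
    lookup = text.Lookup(vocab, unknown=1)                 # default id for out-of-vocab words
    data = ds.TextFileDataset("corpus.txt", shuffle=False)
    data = data.map(input_columns=["text"], operations=lookup)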
@@ -41,6 +61,7 @@ class JiebaTokenizer(cde.JiebaTokenizerOp):
         "HMM" mode will tokenize with Hiddel Markov Model Segment algorithm,
         "MIX" model will tokenize with a mix of MPSegment and HMMSegment algorithm.
     """

     @check_jieba_init
     def __init__(self, hmm_path, mp_path, mode=JiebaMode.MIX):
         self.mode = mode
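Note: a hedged sketch of how the JiebaTokenizer above is typically constructed (dictionary paths are placeholders for the cppjieba HMM/MP dictionary files used in the tests later in this diff; the freq keyword is inferred from check_jieba_add_word):

    import mindspore.dataset as ds
    from mindspore.dataset.text import JiebaTokenizer, JiebaMode

    jieba_op = JiebaTokenizer("hmm_model.utf8", "jieba.dict.utf8", mode=JiebaMode.MP)
    jieba_op.add_word("长江大桥", freq=9999999)   # optional user word, validated by check_jieba_add_word
    data = ds.TextFileDataset("3.txt")
    data = data.map(input_columns=["text"], operations=jieba_op)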
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-c transforms for all text related operators
+Some basic function for nlp
 """
+from enum import IntEnum

 import mindspore._c_dataengine as cde
-from .validators import check_lookup, check_from_list, check_from_dict, check_from_file
+import numpy as np
+
+from .validators import check_from_file, check_from_list, check_from_dict


 class Vocab(cde.Vocab):
@@ -61,17 +64,43 @@ class Vocab(cde.Vocab):
         return super().from_dict(word_dict)


-class Lookup(cde.LookupOp):
+def to_str(array, encoding='utf8'):
     """
-    Lookup operator that looks up a word to an id
+    Convert numpy array of `bytes` to array of `str` by decoding each element based on charset `encoding`.

     Args:
-        vocab(Vocab): a Vocab object
-        unknown(None,int): default id to lookup a word that is out of vocab
+        array (numpy array): Array of type `bytes` representing strings.
+        encoding (string): Indicating the charset for decoding.
+    Returns:
+        Numpy array of `str`.
+
     """

-    @check_lookup
-    def __init__(self, vocab, unknown=None):
-        if unknown is None:
-            super().__init__(vocab)
-        else:
-            super().__init__(vocab, unknown)
+    if not isinstance(array, np.ndarray):
+        raise ValueError('input should be a numpy array')
+
+    return np.char.decode(array, encoding)
+
+
+def to_bytes(array, encoding='utf8'):
+    """
+    Convert numpy array of `str` to array of `bytes` by encoding each element based on charset `encoding`.
+
+    Args:
+        array (numpy array): Array of type `str` representing strings.
+        encoding (string): Indicating the charset for encoding.
+    Returns:
+        Numpy array of `bytes`.
+
+    """
+
+    if not isinstance(array, np.ndarray):
+        raise ValueError('input should be a numpy array')
+
+    return np.char.encode(array, encoding)
+
+
+class JiebaMode(IntEnum):
+    MIX = 0
+    MP = 1
+    HMM = 2
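Note: to_str and to_bytes above are thin wrappers over numpy's elementwise np.char codecs, replacing the old np.vectorize-based as_text. A round-trip sketch, assuming the module as introduced in this commit:

    import numpy as np
    from mindspore.dataset.text import to_str, to_bytes

    decoded = np.array(["home", "world"])          # unicode ('U') array
    encoded = to_bytes(decoded)                    # bytes ('S') array via np.char.encode
    assert (to_str(encoded) == decoded).all()      # decoding restores the original strings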
@@ -17,8 +17,11 @@ validators for text ops
 """

 from functools import wraps

 import mindspore._c_dataengine as cde

+from ..transforms.validators import check_uint32
+
+
 def check_lookup(method):
     """A wrapper that wrap a parameter checker to the original function(crop operation)."""
@@ -106,3 +109,67 @@ def check_from_dict(method):
         return method(self, **kwargs)

     return new_method
+
+
+def check_jieba_init(method):
+    """Wrapper method to check the parameters of jieba add word."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
+
+        if "hmm_path" in kwargs:
+            hmm_path = kwargs.get("hmm_path")
+        if "mp_path" in kwargs:
+            mp_path = kwargs.get("mp_path")
+        if hmm_path is None:
+            raise ValueError(
+                "the dict of HMMSegment in cppjieba is not provided")
+        kwargs["hmm_path"] = hmm_path
+        if mp_path is None:
+            raise ValueError(
+                "the dict of MPSegment in cppjieba is not provided")
+        kwargs["mp_path"] = mp_path
+        if model is not None:
+            kwargs["model"] = model
+        return method(self, **kwargs)
+
+    return new_method
+
+
+def check_jieba_add_word(method):
+    """Wrapper method to check the parameters of jieba add word."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        word, freq = (list(args) + 2 * [None])[:2]
+
+        if "word" in kwargs:
+            word = kwargs.get("word")
+        if "freq" in kwargs:
+            freq = kwargs.get("freq")
+        if word is None:
+            raise ValueError("word is not provided")
+        kwargs["word"] = word
+        if freq is not None:
+            check_uint32(freq)
+        kwargs["freq"] = freq
+        return method(self, **kwargs)
+
+    return new_method
+
+
+def check_jieba_add_dict(method):
+    """Wrapper method to check the parameters of add dict"""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        user_dict = (list(args) + [None])[0]
+        if "user_dict" in kwargs:
+            user_dict = kwargs.get("user_dict")
+        if user_dict is None:
+            raise ValueError("user_dict is not provided")
+        kwargs["user_dict"] = user_dict
+        return method(self, **kwargs)
+
+    return new_method
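Note: the three validators added above all follow the same pattern: pull positional arguments into kwargs, check the required ones, then call the wrapped __init__ with keywords only. A stripped-down, generic illustration of that pattern (not copied from the module; names are hypothetical):

    from functools import wraps

    def check_example(method):
        """Validate a single required 'path' argument before calling method."""
        @wraps(method)
        def new_method(self, *args, **kwargs):
            path = (list(args) + [None])[0]     # first positional argument, if given
            path = kwargs.get("path", path)     # keyword form wins when both are present
            if path is None:
                raise ValueError("path is not provided")
            kwargs["path"] = path
            return method(self, **kwargs)
        return new_method

    class Example:
        @check_example
        def __init__(self, path):
            self.path = path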
@@ -1,21 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This module is to support nlp augmentations. It includes two parts:
-c_transforms and py_transforms. C_transforms is a high performance
-image augmentation module which is developed with c++ opencv. Py_transforms
-provide more kinds of image augmentations which is developed with python PIL.
-"""
-from .utils import as_text, JiebaMode
-from . import c_transforms
@@ -1,43 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Some basic function for nlp
-"""
-from enum import IntEnum
-import numpy as np
-
-
-def as_text(array, encoding='utf8'):
-    """
-    Convert data of array to unicode.
-
-    Args:
-        array (numpy array): Data of array should be ASCII values of each character after converted.
-        encoding (string): Indicating the charset for decoding.
-    Returns:
-        A 'str' object.
-
-    """
-
-    if not isinstance(array, np.ndarray):
-        raise ValueError('input should be a numpy array')
-
-    decode = np.vectorize(lambda x: x.decode(encoding))
-    return decode(array)
-
-
-class JiebaMode(IntEnum):
-    MIX = 0
-    MP = 1
-    HMM = 2
@@ -1,79 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Validators for TensorOps.
-"""
-from functools import wraps
-from ...transforms.validators import check_uint32
-
-
-def check_jieba_init(method):
-    """Wrapper method to check the parameters of jieba add word."""
-
-    @wraps(method)
-    def new_method(self, *args, **kwargs):
-        hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
-
-        if "hmm_path" in kwargs:
-            hmm_path = kwargs.get("hmm_path")
-        if "mp_path" in kwargs:
-            mp_path = kwargs.get("mp_path")
-        if hmm_path is None:
-            raise ValueError(
-                "the dict of HMMSegment in cppjieba is not provided")
-        kwargs["hmm_path"] = hmm_path
-        if mp_path is None:
-            raise ValueError(
-                "the dict of MPSegment in cppjieba is not provided")
-        kwargs["mp_path"] = mp_path
-        if model is not None:
-            kwargs["model"] = model
-        return method(self, **kwargs)
-    return new_method
-
-
-def check_jieba_add_word(method):
-    """Wrapper method to check the parameters of jieba add word."""
-
-    @wraps(method)
-    def new_method(self, *args, **kwargs):
-        word, freq = (list(args) + 2 * [None])[:2]
-
-        if "word" in kwargs:
-            word = kwargs.get("word")
-        if "freq" in kwargs:
-            freq = kwargs.get("freq")
-        if word is None:
-            raise ValueError("word is not provided")
-        kwargs["word"] = word
-        if freq is not None:
-            check_uint32(freq)
-        kwargs["freq"] = freq
-        return method(self, **kwargs)
-    return new_method
-
-
-def check_jieba_add_dict(method):
-    """Wrapper method to check the parameters of add dict"""
-
-    @wraps(method)
-    def new_method(self, *args, **kwargs):
-        user_dict = (list(args) + [None])[0]
-        if "user_dict" in kwargs:
-            user_dict = kwargs.get("user_dict")
-        if user_dict is None:
-            raise ValueError("user_dict is not provided")
-        kwargs["user_dict"] = user_dict
-        return method(self, **kwargs)
-    return new_method
@@ -18,7 +18,7 @@
 #include <string_view>

 #include "common/common.h"
-#include "dataset/kernels/text/jieba_tokenizer_op.h"
+#include "dataset/text/kernels/jieba_tokenizer_op.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"

@@ -18,7 +18,7 @@
 #include <string_view>

 #include "common/common.h"
-#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"

@@ -0,0 +1,2 @@
+home is behind the world ahead
+is behind home ahead world the
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.text.utils as nlp
 from mindspore import log as logger

 DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
@@ -24,7 +24,6 @@ def test_flat_map_1():
     '''
     DATA_FILE records the path of image folders, load the images from them.
     '''
-    import mindspore.dataset.transforms.text.utils as nlp

     def flat_map_func(x):
         data_dir = x[0].item().decode('utf8')
@@ -45,7 +44,6 @@ def test_flat_map_2():
     '''
     Flatten 3D structure data
    '''
-    import mindspore.dataset.transforms.text.utils as nlp

     def flat_map_func_1(x):
         data_dir = x[0].item().decode('utf8')
@@ -27,7 +27,7 @@ import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 from mindspore.dataset.transforms.vision import Inter
-from mindspore.dataset.transforms.text import as_text
+from mindspore.dataset.text import to_str
 from mindspore.mindrecord import FileWriter

 FILES_NUM = 4
@@ -73,7 +73,7 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
     for item in data_set.create_dict_iterator():
         logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
-                {}------------------------".format(as_text(item["file_name"])))
+                {}------------------------".format(to_str(item["file_name"])))
         logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1

@@ -91,7 +91,7 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
     for item in data_set.create_dict_iterator():
         logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
-                {}------------------------".format(as_text(item["file_name"])))
+                {}------------------------".format(to_str(item["file_name"])))
         logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1

@@ -109,7 +109,7 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
     for item in data_set.create_dict_iterator():
         logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
-                {}------------------------".format(as_text(item["file_name"])))
+                {}------------------------".format(to_str(item["file_name"])))
         logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1

@@ -126,7 +126,7 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
     for item in data_set.create_dict_iterator():
         logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
-                {}------------------------".format(as_text(item["file_name"])))
+                {}------------------------".format(to_str(item["file_name"])))
         logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1

@@ -0,0 +1,46 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+
+import mindspore.dataset as ds
+import mindspore.dataset.text as text
+
+# this file contains "home is behind the world head" each word is 1 line
+DATA_FILE = "../data/dataset/testVocab/words.txt"
+VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
+HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8"
+MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8"
+
+
+def test_on_tokenized_line():
+    data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False)
+    jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP)
+    with open(VOCAB_FILE, 'r') as f:
+        for line in f:
+            word = line.split(',')[0]
+            jieba_op.add_word(word)
+    data = data.map(input_columns=["text"], operations=jieba_op)
+    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
+    lookup = text.Lookup(vocab)
+    data = data.map(input_columns=["text"], operations=lookup)
+    res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
+                    [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32)
+    for i, d in enumerate(data.create_dict_iterator()):
+        np.testing.assert_array_equal(d["text"], res[i]), i
+
+
+if __name__ == '__main__':
+    test_on_tokenized_line()
@@ -14,8 +14,8 @@
 # ==============================================================================
 import numpy as np
 import mindspore.dataset as ds
-from mindspore.dataset.transforms.text.c_transforms import JiebaTokenizer
-from mindspore.dataset.transforms.text.utils import JiebaMode, as_text
+from mindspore.dataset.text import JiebaTokenizer
+from mindspore.dataset.text import JiebaMode, to_str

 DATA_FILE = "../data/dataset/testJiebaDataset/3.txt"
 DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*"
@@ -33,7 +33,7 @@ def test_jieba_1():
     expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧']
     ret = []
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -46,7 +46,7 @@ def test_jieba_1_1():
                     operations=jieba_op, num_parallel_workers=1)
     expect = ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -59,7 +59,7 @@ def test_jieba_1_2():
                     operations=jieba_op, num_parallel_workers=1)
     expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -74,7 +74,7 @@ def test_jieba_2():
     data = data.map(input_columns=["text"],
                     operations=jieba_op, num_parallel_workers=2)
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -89,7 +89,7 @@ def test_jieba_2_1():
                     operations=jieba_op, num_parallel_workers=2)
     expect = ['男默女泪', '市', '长江大桥']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -113,7 +113,7 @@ def test_jieba_2_3():
                     operations=jieba_op, num_parallel_workers=2)
     expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -131,7 +131,7 @@ def test_jieba_3():
                     operations=jieba_op, num_parallel_workers=1)
     expect = ['男默女泪', '市', '长江大桥']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -150,7 +150,7 @@ def test_jieba_3_1():
                     operations=jieba_op, num_parallel_workers=1)
     expect = ['男默女泪', '市长', '江大桥']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -166,7 +166,7 @@ def test_jieba_4():
                     operations=jieba_op, num_parallel_workers=1)
     expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -192,7 +192,7 @@ def test_jieba_5():
                     operations=jieba_op, num_parallel_workers=1)
     expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -203,7 +203,7 @@ def gen():


 def pytoken_op(input_data):
-    te = str(as_text(input_data))
+    te = str(to_str(input_data))
     tokens = []
     tokens.append(te[:5].encode("UTF8"))
     tokens.append(te[5:10].encode("UTF8"))
@@ -217,7 +217,7 @@ def test_jieba_6():
                     operations=pytoken_op, num_parallel_workers=1)
     expect = ['今天天气太', '好了我们一', '起去外面玩吧']
     for i in data.create_dict_iterator():
-        ret = as_text(i["text"])
+        ret = to_str(i["text"])
         for index, item in enumerate(ret):
             assert item == expect[index]

@@ -16,6 +16,8 @@ import mindspore._c_dataengine as cde
 import numpy as np
 import pytest

+from mindspore.dataset.text import to_str, to_bytes
+
 import mindspore.dataset as ds
 import mindspore.common.dtype as mstype

@@ -64,7 +66,8 @@ def test_map():
     data = ds.GeneratorDataset(gen, column_names=["col"])

     def split(b):
-        splits = b.item().decode("utf8").split()
+        s = to_str(b)
+        splits = s.item().split()
         return np.array(splits, dtype='S')

     data = data.map(input_columns=["col"], operations=split)
@@ -73,11 +76,20 @@ def test_map():
     np.testing.assert_array_equal(d[0], expected)


-def as_str(arr):
-    def decode(s): return s.decode("utf8")
+def test_map2():
+    def gen():
+        yield np.array(["ab cde 121"], dtype='S'),

-    decode_v = np.vectorize(decode)
-    return decode_v(arr)
+    data = ds.GeneratorDataset(gen, column_names=["col"])
+
+    def upper(b):
+        out = np.char.upper(b)
+        return out
+
+    data = data.map(input_columns=["col"], operations=upper)
+    expected = np.array(["AB CDE 121"], dtype='S')
+    for d in data:
+        np.testing.assert_array_equal(d[0], expected)


 line = np.array(["This is a text file.",
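Note: test_map2 works because numpy's np.char routines operate elementwise on byte-string ('S' dtype) arrays, so upper() can run inside the pipeline without an explicit decode step. A standalone check of the same behaviour:

    import numpy as np

    col = np.array(["ab cde 121"], dtype='S')
    assert np.char.upper(col)[0] == b"AB CDE 121"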
@@ -105,9 +117,9 @@ def test_tfrecord1():
         assert d["line"].shape == line[i].shape
         assert d["words"].shape == words[i].shape
         assert d["chinese"].shape == chinese[i].shape
-        np.testing.assert_array_equal(line[i], as_str(d["line"]))
-        np.testing.assert_array_equal(words[i], as_str(d["words"]))
-        np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
+        np.testing.assert_array_equal(line[i], to_str(d["line"]))
+        np.testing.assert_array_equal(words[i], to_str(d["words"]))
+        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))


 def test_tfrecord2():
@@ -117,9 +129,9 @@ def test_tfrecord2():
         assert d["line"].shape == line[i].shape
         assert d["words"].shape == words[i].shape
         assert d["chinese"].shape == chinese[i].shape
-        np.testing.assert_array_equal(line[i], as_str(d["line"]))
-        np.testing.assert_array_equal(words[i], as_str(d["words"]))
-        np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
+        np.testing.assert_array_equal(line[i], to_str(d["line"]))
+        np.testing.assert_array_equal(words[i], to_str(d["words"]))
+        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))


 def test_tfrecord3():
@@ -134,9 +146,9 @@ def test_tfrecord3():
         assert d["line"].shape == line[i].shape
         assert d["words"].shape == words[i].reshape([2, 2]).shape
         assert d["chinese"].shape == chinese[i].shape
-        np.testing.assert_array_equal(line[i], as_str(d["line"]))
-        np.testing.assert_array_equal(words[i].reshape([2, 2]), as_str(d["words"]))
-        np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
+        np.testing.assert_array_equal(line[i], to_str(d["line"]))
+        np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"]))
+        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))


 def create_text_mindrecord():
@@ -166,16 +178,17 @@ def test_mindrecord():
     for i, d in enumerate(data.create_dict_iterator()):
         assert d["english"].shape == line[i].shape
         assert d["chinese"].shape == chinese[i].shape
-        np.testing.assert_array_equal(line[i], as_str(d["english"]))
-        np.testing.assert_array_equal(chinese[i], as_str(d["chinese"]))
+        np.testing.assert_array_equal(line[i], to_str(d["english"]))
+        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))


 if __name__ == '__main__':
-    # test_generator()
-    # test_basic()
-    # test_batching_strings()
+    test_generator()
+    test_basic()
+    test_batching_strings()
     test_map()
-    # test_tfrecord1()
-    # test_tfrecord2()
-    # test_tfrecord3()
-    # test_mindrecord()
+    test_map2()
+    test_tfrecord1()
+    test_tfrecord2()
+    test_tfrecord3()
+    test_mindrecord()
@@ -17,8 +17,7 @@ Testing UnicodeCharTokenizer op in DE
 """
 import mindspore.dataset as ds
 from mindspore import log as logger
-import mindspore.dataset.transforms.text.c_transforms as nlp
-import mindspore.dataset.transforms.text.utils as nlp_util
+import mindspore.dataset.text as nlp

 DATA_FILE = "../data/dataset/testTokenizerData/1.txt"

@@ -43,7 +42,7 @@ def test_unicode_char_tokenizer():
     dataset = dataset.map(operations=tokenizer)
     tokens = []
     for i in dataset.create_dict_iterator():
-        text = nlp_util.as_text(i['text']).tolist()
+        text = nlp.to_str(i['text']).tolist()
         tokens.append(text)
     logger.info("The out tokens is : {}".format(tokens))
     assert split_by_unicode_char(input_strs) == tokens