From f99204b292062cd4cbf4065cf7d63a2ce65091b3 Mon Sep 17 00:00:00 2001
From: YangLuo
Date: Thu, 18 Mar 2021 15:28:08 +0800
Subject: [PATCH] fix python tokenizer

---
 mindspore/dataset/text/transforms.py       |  4 +++-
 mindspore/dataset/text/utils.py            |  2 +-
 tests/ut/python/dataset/test_eager_text.py | 13 +++++++++----
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py
index 5e5342a000e..8b34ff280e1 100644
--- a/mindspore/dataset/text/transforms.py
+++ b/mindspore/dataset/text/transforms.py
@@ -533,7 +533,9 @@ class PythonTokenizer:
         self.random = False
 
     def __call__(self, in_array):
-        if not isinstance(in_array, str):
+        if not isinstance(in_array, np.ndarray):
+            raise TypeError("input should be a NumPy array. Got {}.".format(type(in_array)))
+        if in_array.dtype.type is np.bytes_:
             in_array = to_str(in_array)
         tokens = self.tokenizer(in_array)
         return tokens
diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py
index d8fb88d4b25..5fcafb16205 100644
--- a/mindspore/dataset/text/utils.py
+++ b/mindspore/dataset/text/utils.py
@@ -216,7 +216,7 @@ def to_str(array, encoding='utf8'):
     """
     if not isinstance(array, np.ndarray):
-        raise ValueError('input should be a NumPy array.')
+        raise TypeError('input should be a NumPy array.')
 
     return np.char.decode(array, encoding)
 
 
diff --git a/tests/ut/python/dataset/test_eager_text.py b/tests/ut/python/dataset/test_eager_text.py
index 76a26e55040..af90b660582 100644
--- a/tests/ut/python/dataset/test_eager_text.py
+++ b/tests/ut/python/dataset/test_eager_text.py
@@ -52,12 +52,17 @@ def test_python_tokenizer():
         if not words:
             return [""]
         return words
-    txt = "Welcome to Beijing !"
-    txt = T.PythonTokenizer(my_tokenizer)(txt)
-    logger.info("Tokenize result: {}".format(txt))
+    txt1 = np.array("Welcome to Beijing !".encode())
+    txt1 = T.PythonTokenizer(my_tokenizer)(txt1)
+    logger.info("Tokenize result: {}".format(txt1))
+
+    txt2 = np.array("Welcome to Beijing !")
+    txt2 = T.PythonTokenizer(my_tokenizer)(txt2)
+    logger.info("Tokenize result: {}".format(txt2))
 
     expected = ['Welcome', 'to', 'Beijing', '!']
-    np.testing.assert_equal(txt, expected)
+    np.testing.assert_equal(txt1, expected)
+    np.testing.assert_equal(txt2, expected)
 
 
 if __name__ == '__main__':