From 0af6d75761ee48eee8ee84de36ee8076acb35fa4 Mon Sep 17 00:00:00 2001
From: xulei2020 <xulei83@huawei.com>
Date: Fri, 24 Jul 2020 18:43:51 +0800
Subject: [PATCH] dataset: load SentencePiece tokenizer model once at construction

---
 .../kernels/sentence_piece_tokenizer_op.cc    | 31 ++++++++------
 .../kernels/sentence_piece_tokenizer_op.h     |  1 +
 .../dataset/test_sentencepiece_tokenizer.py   | 41 ++++++++++++++++++-
 3 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc
index 42fefa20068..fba3770d385 100644
--- a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc
@@ -27,17 +27,34 @@ namespace dataset {
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
                                                    const SPieceTokenizerLoadType load_type,
                                                    const SPieceTokenizerOutType out_type)
-    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
+    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {
+  auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parse vocab model failed.");
+  } else {
+    model_status_ = Status::OK();
+  }
+}
 
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
                                                    const SPieceTokenizerLoadType load_type,
                                                    const SPieceTokenizerOutType out_type)
     : load_type_(load_type), out_type_(out_type) {
   (void)GetModelRealPath(model_path, model_filename);
+  auto status = processor_.Load(file_path_);
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model failed.");
+  } else {
+    model_status_ = Status::OK();
+  }
 }
 
 Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
+  if (!model_status_.IsOk()) {
+    return model_status_;
+  }
+
   if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
     RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
   }
@@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s
   std::string_view sentence_v;
   RETURN_IF_NOT_OK(input->GetItemAt<std::string_view>(&sentence_v, {}));
   std::string sentence{sentence_v};
-  if (load_type_ == SPieceTokenizerLoadType::kFile) {
-    auto status = processor_.Load(file_path_);
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
-    }
-  } else {
-    RETURN_UNEXPECTED_IF_NULL(vocab_);
-    auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
-    }
-  }
 
   if (out_type_ == SPieceTokenizerOutType::kString) {
     std::vector<std::string> pieces;
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h
index 130842cb774..3cc97078cb7 100644
--- a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h
@@ -58,6 +58,7 @@ class SentencePieceTokenizerOp : public TensorOp {
   std::string file_path_;
   SPieceTokenizerLoadType load_type_;
   sentencepiece::SentencePieceProcessor processor_;
+  Status model_status_;
 };
 }  // namespace dataset
 }  // namespace mindspore
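The .cc hunks above move the SentencePiece model load out of Compute() and into the two constructors, caching the load result in model_status_ so that every later Compute() call fails fast with the same status instead of reloading the model per input row. A minimal Python sketch of that load-once-and-cache pattern, with hypothetical names (this is not the MindSpore API, just the shape of the change):

    class CachedLoadTokenizer:
        """Load the model once at construction; fail every call if the load failed."""

        def __init__(self, model_path):
            try:
                # Stand-in for processor_.Load(file_path_) in the constructor.
                with open(model_path, "rb") as f:
                    self._model = f.read()
                self._status = None  # None mirrors Status::OK()
            except OSError as err:
                self._model = None
                self._status = err   # cached, mirrors model_status_

        def compute(self, sentence):
            # Mirrors the early return: if (!model_status_.IsOk()) return model_status_;
            if self._status is not None:
                raise RuntimeError("load vocab model failed.") from self._status
            # Stand-in for the real SentencePiece encoding.
            return sentence.split()

One consequence of this design: a bad model path or proto is detected once, at op construction, and the cached error is reported identically on every row rather than being re-discovered inside each worker thread.
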
diff --git a/tests/ut/python/dataset/test_sentencepiece_tokenizer.py b/tests/ut/python/dataset/test_sentencepiece_tokenizer.py
index e78c58e5a33..d50ed01e7be 100644
--- a/tests/ut/python/dataset/test_sentencepiece_tokenizer.py
+++ b/tests/ut/python/dataset/test_sentencepiece_tokenizer.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+import copy
 import mindspore.dataset.text as text
 import mindspore.dataset as ds
 from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
@@ -121,6 +121,44 @@ def test_build_from_dataset():
         assert value == expect[key]
 
 
+def apply_func(dataset):
+    input_columns = ['text']
+    output_columns = ['text2']
+    dataset = dataset.rename(input_columns, output_columns)
+    return dataset
+
+
+def zip_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset_2 = copy.deepcopy(dataset)
+    dataset_1 = dataset_1.apply(apply_func)
+    dataset_zip = ds.zip((dataset_1, dataset_2))
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset_zip.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def concat_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset = dataset.concat(dataset_1)
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def test_with_zip_concat():
+    data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
+    vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
+    tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)
+    zip_test(dataset)
+    concat_test(dataset)
+
+
 if __name__ == "__main__":
     test_from_vocab_to_str_UNIGRAM()
     test_from_vocab_to_str_BPE()
@@ -130,3 +168,4 @@ if __name__ == "__main__":
     test_from_file_to_str()
     test_from_file_to_int()
     test_build_from_dataset()
+    test_with_zip_concat()
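The new test matters because zip() cannot take the same dataset object twice and concat() branches the pipeline, so the test deep-copies the tokenized dataset; each executed copy constructs its own SentencePieceTokenizerOp, and with this patch the model is parsed once per op instance instead of once per row. A condensed sketch of that pipeline, assuming the same MindSpore 1.x-era API the test uses (file paths are placeholders standing in for VOCAB_FILE/DATA_FILE):

    import copy
    import mindspore.dataset as ds
    import mindspore.dataset.text as text
    from mindspore.dataset.text import SentencePieceModel, SPieceTokenizerOutType, to_str

    # Build the tokenized pipeline once, as in test_with_zip_concat above.
    data = ds.TextFileDataset("vocab_corpus.txt", shuffle=False)      # placeholder path
    vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995,
                                                 SentencePieceModel.UNIGRAM, {})
    tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
    dataset = ds.TextFileDataset("sentences.txt", shuffle=False)      # placeholder path
    dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)

    # Deep-copy so each zip branch owns its own op tree; one branch is renamed so
    # the zipped columns do not collide.
    branch = copy.deepcopy(dataset).rename(['text'], ['text2'])
    zipped = ds.zip((branch, copy.deepcopy(dataset)))
    for row in zipped.create_dict_iterator():
        print(to_str(row["text"]))

    # concat duplicates the tokenizer op the same way.
    doubled = dataset.concat(copy.deepcopy(dataset))
    for row in doubled.create_dict_iterator():
        print(to_str(row["text"]))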