This commit is contained in:
xulei2020 2020-07-24 18:43:51 +08:00
parent 50e20e4042
commit c43bc92d7c
3 changed files with 59 additions and 14 deletions

View File

@ -27,17 +27,34 @@ namespace dataset {
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
const SPieceTokenizerLoadType load_type,
const SPieceTokenizerOutType out_type)
: vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
: vocab_(vocab), load_type_(load_type), out_type_(out_type) {
auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
if (!status.ok()) {
model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parser vocab model filed.");
} else {
model_status_ = Status::OK();
}
}
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
const SPieceTokenizerLoadType load_type,
const SPieceTokenizerOutType out_type)
: load_type_(load_type), out_type_(out_type) {
(void)GetModelRealPath(model_path, model_filename);
auto status = processor_.Load(file_path_);
if (!status.ok()) {
model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model filed.");
} else {
model_status_ = Status::OK();
}
}
Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (!model_status_.IsOk()) {
return model_status_;
}
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
}
@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s
std::string_view sentence_v;
RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
std::string sentence{sentence_v};
if (load_type_ == SPieceTokenizerLoadType::kFile) {
auto status = processor_.Load(file_path_);
if (!status.ok()) {
RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
}
} else {
RETURN_UNEXPECTED_IF_NULL(vocab_);
auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
if (!status.ok()) {
RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
}
}
if (out_type_ == SPieceTokenizerOutType::kString) {
std::vector<std::string> pieces;

View File

@ -58,6 +58,7 @@ class SentencePieceTokenizerOp : public TensorOp {
std::string file_path_;
SPieceTokenizerLoadType load_type_;
sentencepiece::SentencePieceProcessor processor_;
Status model_status_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import mindspore.dataset.text as text
import mindspore.dataset as ds
from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
@ -84,9 +84,48 @@ def test_build_from_dataset():
assert value == expect[key]
def apply_func(dataset):
input_columns = ['text']
output_columns = ['text2']
dataset = dataset.rename(input_columns, output_columns)
return dataset
def zip_test(dataset):
dataset_1 = copy.deepcopy(dataset)
dataset_2 = copy.deepcopy(dataset)
dataset_1 = dataset_1.apply(apply_func)
dataset_zip = ds.zip((dataset_1, dataset_2))
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
for i in dataset_zip.create_dict_iterator():
ret = to_str(i["text"])
for key, value in enumerate(ret):
assert value == expect[key]
def concat_test(dataset):
dataset_1 = copy.deepcopy(dataset)
dataset = dataset.concat(dataset_1)
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
for i in dataset.create_dict_iterator():
ret = to_str(i["text"])
for key, value in enumerate(ret):
assert value == expect[key]
def test_with_zip_concat():
data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)
zip_test(dataset)
concat_test(dataset)
if __name__ == "__main__":
test_from_vocab_to_str()
test_from_vocab_to_int()
test_from_file_to_str()
test_from_file_to_int()
test_build_from_dataset()
test_with_zip_concat()