forked from mindspore-Ecosystem/mindspore
add code
This commit is contained in:
parent
c44939afc8
commit
0af6d75761
|
@ -27,17 +27,34 @@ namespace dataset {
|
|||
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
|
||||
const SPieceTokenizerLoadType load_type,
|
||||
const SPieceTokenizerOutType out_type)
|
||||
: vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
|
||||
: vocab_(vocab), load_type_(load_type), out_type_(out_type) {
|
||||
auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
|
||||
if (!status.ok()) {
|
||||
model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parser vocab model filed.");
|
||||
} else {
|
||||
model_status_ = Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
|
||||
const SPieceTokenizerLoadType load_type,
|
||||
const SPieceTokenizerOutType out_type)
|
||||
: load_type_(load_type), out_type_(out_type) {
|
||||
(void)GetModelRealPath(model_path, model_filename);
|
||||
auto status = processor_.Load(file_path_);
|
||||
if (!status.ok()) {
|
||||
model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model filed.");
|
||||
} else {
|
||||
model_status_ = Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
if (!model_status_.IsOk()) {
|
||||
return model_status_;
|
||||
}
|
||||
|
||||
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
|
||||
RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
|
||||
}
|
||||
|
@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s
|
|||
std::string_view sentence_v;
|
||||
RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
|
||||
std::string sentence{sentence_v};
|
||||
if (load_type_ == SPieceTokenizerLoadType::kFile) {
|
||||
auto status = processor_.Load(file_path_);
|
||||
if (!status.ok()) {
|
||||
RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
|
||||
}
|
||||
} else {
|
||||
RETURN_UNEXPECTED_IF_NULL(vocab_);
|
||||
auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
|
||||
if (!status.ok()) {
|
||||
RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
|
||||
}
|
||||
}
|
||||
|
||||
if (out_type_ == SPieceTokenizerOutType::kString) {
|
||||
std::vector<std::string> pieces;
|
||||
|
|
|
@ -58,6 +58,7 @@ class SentencePieceTokenizerOp : public TensorOp {
|
|||
std::string file_path_;
|
||||
SPieceTokenizerLoadType load_type_;
|
||||
sentencepiece::SentencePieceProcessor processor_;
|
||||
Status model_status_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import copy
|
||||
import mindspore.dataset.text as text
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
|
||||
|
@ -121,6 +121,44 @@ def test_build_from_dataset():
|
|||
assert value == expect[key]
|
||||
|
||||
|
||||
def apply_func(dataset):
|
||||
input_columns = ['text']
|
||||
output_columns = ['text2']
|
||||
dataset = dataset.rename(input_columns, output_columns)
|
||||
return dataset
|
||||
|
||||
|
||||
def zip_test(dataset):
|
||||
dataset_1 = copy.deepcopy(dataset)
|
||||
dataset_2 = copy.deepcopy(dataset)
|
||||
dataset_1 = dataset_1.apply(apply_func)
|
||||
dataset_zip = ds.zip((dataset_1, dataset_2))
|
||||
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
|
||||
for i in dataset_zip.create_dict_iterator():
|
||||
ret = to_str(i["text"])
|
||||
for key, value in enumerate(ret):
|
||||
assert value == expect[key]
|
||||
|
||||
|
||||
def concat_test(dataset):
|
||||
dataset_1 = copy.deepcopy(dataset)
|
||||
dataset = dataset.concat(dataset_1)
|
||||
expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
|
||||
for i in dataset.create_dict_iterator():
|
||||
ret = to_str(i["text"])
|
||||
for key, value in enumerate(ret):
|
||||
assert value == expect[key]
|
||||
|
||||
def test_with_zip_concat():
|
||||
data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
|
||||
vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
|
||||
tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
|
||||
dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
|
||||
dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)
|
||||
zip_test(dataset)
|
||||
concat_test(dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_from_vocab_to_str_UNIGRAM()
|
||||
test_from_vocab_to_str_BPE()
|
||||
|
@ -130,3 +168,4 @@ if __name__ == "__main__":
|
|||
test_from_file_to_str()
|
||||
test_from_file_to_int()
|
||||
test_build_from_dataset()
|
||||
test_with_zip_concat()
|
||||
|
|
Loading…
Reference in New Issue