forked from OSSInnovation/mindspore
add code
parent c44939afc8
commit 0af6d75761
@@ -27,17 +27,34 @@ namespace dataset {
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
                                                    const SPieceTokenizerLoadType load_type,
                                                    const SPieceTokenizerOutType out_type)
-    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
+    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {
+  auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parser vocab model filed.");
+  } else {
+    model_status_ = Status::OK();
+  }
+}
 
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
                                                    const SPieceTokenizerLoadType load_type,
                                                    const SPieceTokenizerOutType out_type)
     : load_type_(load_type), out_type_(out_type) {
   (void)GetModelRealPath(model_path, model_filename);
+  auto status = processor_.Load(file_path_);
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model filed.");
+  } else {
+    model_status_ = Status::OK();
+  }
 }
 
 Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
+  if (!model_status_.IsOk()) {
+    return model_status_;
+  }
+
   if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
     RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
   }

@@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   std::string_view sentence_v;
   RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
   std::string sentence{sentence_v};
-  if (load_type_ == SPieceTokenizerLoadType::kFile) {
-    auto status = processor_.Load(file_path_);
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
-    }
-  } else {
-    RETURN_UNEXPECTED_IF_NULL(vocab_);
-    auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
-    }
-  }
 
   if (out_type_ == SPieceTokenizerOutType::kString) {
     std::vector<std::string> pieces;

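This second hunk is the other half of the same refactor: now that both constructors load the SentencePiece model up front, Compute() no longer calls processor_.Load() or LoadFromSerializedProto() for every input row; when construction-time loading failed, it simply returns the cached model_status_.
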
@@ -58,6 +58,7 @@ class SentencePieceTokenizerOp : public TensorOp {
   std::string file_path_;
   SPieceTokenizerLoadType load_type_;
   sentencepiece::SentencePieceProcessor processor_;
+  Status model_status_;
 };
 } // namespace dataset
 } // namespace mindspore

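Taken together, the .cc and .h hunks implement a load-once / cache-the-error pattern: the expensive model load happens in the constructor, its outcome is stored in the new model_status_ member, and the per-sample Compute() call only inspects that cached status. Below is a minimal, self-contained sketch of that pattern under simplified assumptions; SimpleStatus and TokenizerSketch are illustrative stand-ins, not the real MindSpore Status or SentencePieceTokenizerOp types.

// Sketch: do the expensive model load once at construction time, remember the
// outcome, and let the per-sample call return the cached error instead of
// reloading the model. SimpleStatus/TokenizerSketch are hypothetical names.
#include <iostream>
#include <string>
#include <utility>

class SimpleStatus {
 public:
  SimpleStatus() = default;  // default-constructed status means OK
  explicit SimpleStatus(std::string msg) : msg_(std::move(msg)) {}
  bool IsOk() const { return msg_.empty(); }
  const std::string &ToString() const { return msg_; }

 private:
  std::string msg_;  // empty string means success
};

class TokenizerSketch {
 public:
  explicit TokenizerSketch(const std::string &model_path) {
    // Stand-in for processor_.Load(model_path): done once, result cached.
    if (model_path.empty()) {
      model_status_ = SimpleStatus("load vocab model failed.");
    } else {
      model_status_ = SimpleStatus();  // equivalent of Status::OK()
    }
  }

  // Per-sample hot path: only check the cached status, never reload the model.
  SimpleStatus Compute(const std::string &sentence) {
    if (!model_status_.IsOk()) {
      return model_status_;
    }
    std::cout << "tokenize: " << sentence << "\n";
    return SimpleStatus();
  }

 private:
  SimpleStatus model_status_;  // cached result of the constructor-time load
};

int main() {
  TokenizerSketch good("m.model");
  good.Compute("I saw a girl with a telescope.");

  TokenizerSketch bad("");
  std::cout << bad.Compute("hello").ToString() << "\n";  // cached load error
  return 0;
}
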
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import copy
 import mindspore.dataset.text as text
 import mindspore.dataset as ds
 from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType

@@ -121,6 +121,44 @@ def test_build_from_dataset():
             assert value == expect[key]
 
 
+def apply_func(dataset):
+    input_columns = ['text']
+    output_columns = ['text2']
+    dataset = dataset.rename(input_columns, output_columns)
+    return dataset
+
+
+def zip_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset_2 = copy.deepcopy(dataset)
+    dataset_1 = dataset_1.apply(apply_func)
+    dataset_zip = ds.zip((dataset_1, dataset_2))
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset_zip.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def concat_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset = dataset.concat(dataset_1)
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def test_with_zip_concat():
+    data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
+    vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
+    tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)
+    zip_test(dataset)
+    concat_test(dataset)
+
+
 if __name__ == "__main__":
     test_from_vocab_to_str_UNIGRAM()
     test_from_vocab_to_str_BPE()

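For context: apply_func renames the 'text' column to 'text2' on one of the deep copies before ds.zip joins the two pipelines, presumably because zipped datasets may not share column names, while concat_test concatenates the tokenized dataset with a copy of itself. Both helpers iterate the resulting dataset and assert that every row still decodes to the expected SentencePiece pieces for "I saw a girl with a telescope.", i.e. that zip and concat do not disturb the tokenizer output.
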
@@ -130,3 +168,4 @@ if __name__ == "__main__":
     test_from_file_to_str()
     test_from_file_to_int()
     test_build_from_dataset()
+    test_with_zip_concat()