forked from mindspore-Ecosystem/mindspore

!1807 Implemented Ngram TensorOp for dataset
Merge pull request !1807 from ZiruiWu/ngram_dev

commit 5c21616293
@@ -59,6 +59,7 @@
 #include "dataset/engine/gnn/graph.h"
 #include "dataset/kernels/data/to_float16_op.h"
 #include "dataset/text/kernels/jieba_tokenizer_op.h"
+#include "dataset/text/kernels/ngram_op.h"
 #include "dataset/text/kernels/unicode_char_tokenizer_op.h"
 #include "dataset/text/vocab.h"
 #include "dataset/text/kernels/lookup_op.h"

@@ -430,6 +431,11 @@ void bindTensorOps5(py::module *m) {
                     "Tensor operation to LookUp each word")
       .def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
       .def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
+  (void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping")
+    .def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &, const std::string &,
+                  const std::string &>(),
+         py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"),
+         py::arg("separator"));
 }
 
 void bindSamplerOps(py::module *m) {
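Editor's note: the binding above registers NgramOp with keyword names that mirror the C++ constructor. A minimal sketch of calling the raw binding directly, assuming the module import path users normally never touch (the argument values here are illustrative; end users should prefer the mindspore.dataset.text.Ngram wrapper added later in this PR):

import mindspore._c_dataengine as cde

# Keyword names match the py::arg(...) declarations registered above; values are illustrative.
op = cde.NgramOp(ngrams=[2, 3], l_pad_len=2, r_pad_len=2, l_pad_token="_", r_pad_token="_", separator=" ")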
@@ -4,4 +4,5 @@ add_library(text-kernels OBJECT
     lookup_op.cc
     jieba_tokenizer_op.cc
     unicode_char_tokenizer_op.cc
+    ngram_op.cc
     )
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef DATASET_NLP_KERNELS_LOOKUP_OP_H_
-#define DATASET_NLP_KERNELS_LOOKUP_OP_H_
+#ifndef DATASET_TEXT_KERNELS_LOOKUP_OP_H_
+#define DATASET_TEXT_KERNELS_LOOKUP_OP_H_
 
 #include <memory>
 #include <vector>

@@ -61,4 +61,4 @@ class LookupOp : public TensorOp {
 }  // namespace dataset
 }  // namespace mindspore
 
-#endif  // DATASET_NLP_KERNELS_LOOKUP_OP_H_
+#endif  // DATASET_TEXT_KERNELS_LOOKUP_OP_H_
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "dataset/text/kernels/ngram_op.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace mindspore {
+namespace dataset {
+
+NgramOp::NgramOp(const std::vector<int32_t> &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad,
+                 const std::string &r_pad, const std::string &separator)
+    : ngrams_(ngrams),
+      l_len_(l_len),
+      r_len_(r_len),
+      l_pad_with_sp_(l_pad + separator),
+      r_pad_with_sp_(r_pad + separator),
+      separator_(separator) {}
+
+Status NgramOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor");
+  std::vector<int32_t> offsets;  // offsets for each str
+  std::vector<std::string> res;  // holds the result of ngrams
+  std::string str_buffer;        // concatenation of pad tokens and input strings, interleaved with separators
+  res.reserve(input->shape().NumOfElements());  // this should be more than enough
+  offsets.reserve(1 + l_len_ + r_len_ + input->shape().NumOfElements());
+  str_buffer.reserve(l_pad_with_sp_.size() * l_len_ + r_pad_with_sp_.size() * r_len_ + input->SizeInBytes());
+  offsets.push_back(str_buffer.size());  // insert 0 as the starting pos
+  for (int i = 0; i < l_len_; i++) offsets.push_back((str_buffer += l_pad_with_sp_).size());
+
+  for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
+    str_buffer += (*itr);
+    str_buffer += separator_;
+    offsets.push_back(str_buffer.size());
+  }
+
+  for (int i = 0; i < r_len_; i++) offsets.push_back((str_buffer += r_pad_with_sp_).size());
+
+  for (auto n : ngrams_) {
+    CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "n gram needs to be a positive number.\n");
+    int32_t start_ind = l_len_ - std::min(l_len_, n - 1);
+    int32_t end_ind = offsets.size() - r_len_ + std::min(r_len_, n - 1);
+    if (end_ind - start_ind < n) {
+      res.emplace_back(std::string());  // push back empty string
+    } else {
+      for (int i = start_ind; i < end_ind - n; i++) {
+        res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size()));
+      }
+    }
+  }
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, res, TensorShape({static_cast<dsize_t>(res.size())})));
+  return Status::OK();
+}
+
+void NgramOp::Print(std::ostream &out) const {
+  out << "NgramOp: "
+      << "left pad width: " << l_len_ << " left pad token with separator: " << l_pad_with_sp_ << "\n"
+      << "right pad width: " << r_len_ << " right pad token with separator: " << r_pad_with_sp_ << "\n"
+      << "separator: " << separator_ << "\n";
+}
+
+Status NgramOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs[0].Rank() == 1, "ngram only works with 1-dim data\n");
+  dsize_t num_elements = ngrams_.size();
+  for (int32_t n : ngrams_) {
+    // here since rank == 1, NumOfElements == shape[0]. add the padding length to the number of strings
+    int32_t len_with_padding = inputs[0].NumOfElements() + std::min(n - 1, l_len_) + std::min(n - 1, r_len_);
+    // if len_with_padding - n < 0, nothing is added beyond the guaranteed (possibly empty) slot for this n
+    num_elements += std::max(len_with_padding - n, 0);
+  }
+  outputs.emplace_back(TensorShape({num_elements}));
+  CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n");
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
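Editor's note: Compute builds one flat buffer of pad tokens and input strings joined by the separator, records the end offset of each token, and then slices every n-gram directly out of that buffer. A rough Python rendering of the same approach, for intuition only (function and variable names are mine, not the shipped kernel):

# Illustrative sketch of the buffer-and-offsets approach used by NgramOp::Compute above.
def ngram(tokens, ngrams, l_len, r_len, l_pad, r_pad, sep):
    buf, offsets = "", [0]
    for tok in [l_pad] * l_len + list(tokens) + [r_pad] * r_len:
        buf += tok + sep
        offsets.append(len(buf))  # end offset of each token (separator included)
    res = []
    for n in ngrams:
        start = l_len - min(l_len, n - 1)                 # skip pads the gram cannot reach
        end = len(offsets) - r_len + min(r_len, n - 1)
        if end - start < n:
            res.append("")                                # too short: emit one empty string
        else:
            # each gram is a contiguous slice of the buffer, minus the trailing separator
            res.extend(buf[offsets[i]:offsets[i + n] - len(sep)] for i in range(start, end - n))
    return res

assert ngram(["Beautiful", "British", "Columbia"], [2], 0, 0, "", "", " ") == \
    ["Beautiful British", "British Columbia"]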
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_TEXT_KERNELS_NGRAM_OP_H_
+#define DATASET_TEXT_KERNELS_NGRAM_OP_H_
+
+#include <string>
+#include <memory>
+#include <vector>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+namespace py = pybind11;
+
+class NgramOp : public TensorOp {
+ public:
+  // Constructor of Ngram model
+  // @param const std::vector<int32_t> &ngrams - the n values of the grams to generate
+  // @param int32_t l_len - padding length on the left
+  // @param int32_t r_len - padding length on the right
+  // @param const std::string &l_pad - padding token on the left
+  // @param const std::string &r_pad - padding token on the right
+  // @param const std::string &separator - used to join strings
+  NgramOp(const std::vector<int32_t> &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad,
+          const std::string &r_pad, const std::string &separator);
+
+  // perform ngram model on each tensor
+  // @param const std::shared_ptr<Tensor> &input
+  // @param std::shared_ptr<Tensor> *output
+  // @return error code
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  // destructor
+  ~NgramOp() override = default;
+
+  // @param std::vector<TensorShape> &inputs - shape of input tensors
+  // @param std::vector<TensorShape> &outputs - shape of output tensors
+  // @return error code
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  // print arg for debugging
+  // @param std::ostream &out
+  void Print(std::ostream &out) const override;
+
+ private:
+  std::vector<int32_t> ngrams_;  // list of n grams
+  int32_t l_len_;                // left padding length
+  int32_t r_len_;                // right padding length
+  std::string l_pad_with_sp_;    // left padding appended with separator
+  std::string r_pad_with_sp_;    // right padding appended with separator
+  std::string separator_;        // separator
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_TEXT_KERNELS_NGRAM_OP_H_
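Editor's note: the output length computed by OutputShape above can be checked by hand. A small sketch of the same arithmetic, assuming a 1-D input of num_tokens strings (the function name is mine):

# Output length of NgramOp for a 1-D input, per the formula in OutputShape above.
def ngram_output_len(num_tokens, ngrams, l_len, r_len):
    total = len(ngrams)  # one guaranteed slot per n value
    for n in ngrams:
        # padding contributes at most n-1 tokens per side
        len_with_padding = num_tokens + min(n - 1, l_len) + min(n - 1, r_len)
        total += max(len_with_padding - n, 0)
    return total

assert ngram_output_len(3, [2], 0, 0) == 2  # 3 tokens -> 2 bigrams
assert ngram_output_len(1, [3], 0, 0) == 1  # too short: the single slot holds an empty gram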
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef DATASET_NLP_VOCAB_H_
-#define DATASET_NLP_VOCAB_H_
+#ifndef DATASET_TEXT_VOCAB_H_
+#define DATASET_TEXT_VOCAB_H_
 
 #include <string>
 #include <memory>

@@ -87,4 +87,4 @@ class Vocab {
 }  // namespace dataset
 }  // namespace mindspore
 
-#endif  // DATASET_NLP_VOCAB_H_
+#endif  // DATASET_TEXT_VOCAB_H_
@@ -15,5 +15,5 @@
 """
 mindspore.dataset.text
 """
-from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer
+from .transforms import Lookup, JiebaTokenizer, UnicodeCharTokenizer, Ngram
 from .utils import to_str, to_bytes, JiebaMode, Vocab
@@ -22,7 +22,7 @@ import mindspore._c_dataengine as cde
 
 from .utils import JiebaMode
 from .validators import check_lookup, check_jieba_add_dict, \
-    check_jieba_add_word, check_jieba_init
+    check_jieba_add_word, check_jieba_init, check_ngram
 
 
 class Lookup(cde.LookupOp):

@@ -41,6 +41,27 @@ class Lookup(cde.LookupOp):
         super().__init__(vocab, unknown)
 
 
+class Ngram(cde.NgramOp):
+    """
+    TensorOp to generate n-grams from a 1-D string Tensor.
+    Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an explanation of what an n-gram is.
+    Args:
+        n (int or list of int): n in n-gram, n >= 1. n can be a list of positive integers; e.g. n=[4, 3] means the
+            result would be a 4-gram followed by a 3-gram in the same tensor.
+        left_pad (tuple, optional): ("pad_token", pad_width). Padding performed on the left side of the sequence;
+            pad_width will be capped at n-1. left_pad=("_", 2) would pad the left side with "__" (Default is None).
+        right_pad (tuple, optional): ("pad_token", pad_width). Padding performed on the right side of the sequence;
+            pad_width will be capped at n-1. right_pad=("-", 2) would pad the right side with "--" (Default is None).
+        separator (str, optional): symbol used to join strings together; e.g. 2-grams of ["mindspore", "amazing"]
+            with separator="-" give ["mindspore-amazing"] (Default is None, which means whitespace is used).
+    """
+
+    @check_ngram
+    def __init__(self, n, left_pad=None, right_pad=None, separator=None):
+        super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0],
+                         r_pad_token=right_pad[0], separator=separator)
+
+
 DE_C_INTER_JIEBA_MODE = {
     JiebaMode.MIX: cde.JiebaMode.DE_JIEBA_MIX,
     JiebaMode.MP: cde.JiebaMode.DE_JIEBA_MP,
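Editor's note: a minimal usage sketch of the new Ngram wrapper, in the style of the tests added at the end of this PR (the sample strings are mine):

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as nlp

def gen():
    yield (np.array(["hello", "brave", "new", "world"], dtype='S'),)  # one row, one 1-D string column

data = ds.GeneratorDataset(gen(), column_names=["text"])
data = data.map(input_columns=["text"], operations=nlp.Ngram(2))  # separator defaults to a single space
for row in data.create_dict_iterator():
    print([d.decode("utf8") for d in row["text"]])  # ['hello brave', 'brave new', 'new world']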
@@ -34,9 +34,11 @@ def check_lookup(method):
         if "unknown" in kwargs:
             unknown = kwargs.get("unknown")
         if unknown is not None:
-            assert isinstance(unknown, int) and unknown >= 0, "unknown needs to be a non-negative integer"
+            if not (isinstance(unknown, int) and unknown >= 0):
+                raise ValueError("unknown needs to be a non-negative integer")
 
-        assert isinstance(vocab, cde.Vocab), "vocab is not an instance of cde.Vocab"
+        if not isinstance(vocab, cde.Vocab):
+            raise ValueError("vocab is not an instance of cde.Vocab")
 
         kwargs["vocab"] = vocab
         kwargs["unknown"] = unknown

@@ -58,13 +60,17 @@ def check_from_file(method):
         if "vocab_size" in kwargs:
             vocab_size = kwargs.get("vocab_size")
 
-        assert isinstance(file_path, str), "file_path needs to be str"
+        if not isinstance(file_path, str):
+            raise ValueError("file_path needs to be str")
+
         if delimiter is not None:
-            assert isinstance(delimiter, str), "delimiter needs to be str"
+            if not isinstance(delimiter, str):
+                raise ValueError("delimiter needs to be str")
         else:
             delimiter = ""
         if vocab_size is not None:
-            assert isinstance(vocab_size, int) and vocab_size > 0, "vocab size needs to be a positive integer"
+            if not (isinstance(vocab_size, int) and vocab_size > 0):
+                raise ValueError("vocab size needs to be a positive integer")
         else:
             vocab_size = -1
         kwargs["file_path"] = file_path

@@ -83,9 +89,11 @@ def check_from_list(method):
         word_list, = (list(args) + [None])[:1]
         if "word_list" in kwargs:
             word_list = kwargs.get("word_list")
-        assert isinstance(word_list, list), "word_list needs to be a list of words"
+        if not isinstance(word_list, list):
+            raise ValueError("word_list needs to be a list of words")
         for word in word_list:
-            assert isinstance(word, str), "each word in word list needs to be type str"
+            if not isinstance(word, str):
+                raise ValueError("each word in word list needs to be type str")
 
         kwargs["word_list"] = word_list
         return method(self, **kwargs)

@@ -101,10 +109,13 @@ def check_from_dict(method):
         word_dict, = (list(args) + [None])[:1]
         if "word_dict" in kwargs:
             word_dict = kwargs.get("word_dict")
-        assert isinstance(word_dict, dict), "word_dict needs to be a list of word,id pairs"
+        if not isinstance(word_dict, dict):
+            raise ValueError("word_dict needs to be a list of word,id pairs")
         for word, word_id in word_dict.items():
-            assert isinstance(word, str), "each word in word_dict needs to be type str"
-            assert isinstance(word_id, int) and word_id >= 0, "each word id needs to be positive integer"
+            if not isinstance(word, str):
+                raise ValueError("each word in word_dict needs to be type str")
+            if not (isinstance(word_id, int) and word_id >= 0):
+                raise ValueError("each word id needs to be a non-negative integer")
         kwargs["word_dict"] = word_dict
         return method(self, **kwargs)
 

@@ -173,3 +184,61 @@ def check_jieba_add_dict(method):
         return method(self, **kwargs)
 
     return new_method
+
+
+def check_ngram(method):
+    """A wrapper that wraps a parameter checker around the original function (Ngram operation)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4]
+        if "n" in kwargs:
+            n = kwargs.get("n")
+        if "left_pad" in kwargs:
+            left_pad = kwargs.get("left_pad")
+        if "right_pad" in kwargs:
+            right_pad = kwargs.get("right_pad")
+        if "separator" in kwargs:
+            separator = kwargs.get("separator")
+
+        if isinstance(n, int):
+            n = [n]
+
+        if not (isinstance(n, list) and n != []):
+            raise ValueError("n needs to be a non-empty list of positive integers")
+
+        for gram in n:
+            if not (isinstance(gram, int) and gram > 0):
+                raise ValueError("n in ngram needs to be a positive number\n")
+
+        if left_pad is None:
+            left_pad = ("", 0)
+
+        if right_pad is None:
+            right_pad = ("", 0)
+
+        if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
+                left_pad[1], int)):
+            raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width")
+
+        if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str) and isinstance(
+                right_pad[1], int)):
+            raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width")
+
+        if not (left_pad[1] >= 0 and right_pad[1] >= 0):
+            raise ValueError("padding width need to be positive numbers")
+
+        if separator is None:
+            separator = " "
+
+        if not isinstance(separator, str):
+            raise ValueError("separator needs to be a string")
+
+        kwargs["n"] = n
+        kwargs["left_pad"] = left_pad
+        kwargs["right_pad"] = right_pad
+        kwargs["separator"] = separator
+
+        return method(self, **kwargs)
+
+    return new_method
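Editor's note: check_ngram normalizes arguments before they reach the C++ op: an int n becomes a one-element list, missing pads become ("", 0), and a missing separator becomes a single space; bad input raises ValueError with the messages defined above. A quick illustration:

import mindspore.dataset.text as nlp

op = nlp.Ngram(2)  # accepted: n normalized to [2], pads to ("", 0), separator to " "

try:
    nlp.Ngram([])  # rejected by check_ngram
except ValueError as e:
    assert "n needs to be a non-empty list of positive integers" in str(e)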
@@ -0,0 +1,115 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing NgramOp in DE
+"""
+import mindspore.dataset as ds
+import mindspore.dataset.text as nlp
+import numpy as np
+
+
+def test_multiple_ngrams():
+    """ test n-gram where n is a list of integers"""
+    plates_mottos = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
+    n_gram_mottos = []
+    n_gram_mottos.append(
+        ['WildRose', 'Country', '_ WildRose', 'WildRose Country', 'Country _', '_ _ WildRose', '_ WildRose Country',
+         'WildRose Country _', 'Country _ _'])
+    n_gram_mottos.append(
+        ["Canada's", 'Ocean', 'Playground', "_ Canada's", "Canada's Ocean", 'Ocean Playground', 'Playground _',
+         "_ _ Canada's", "_ Canada's Ocean", "Canada's Ocean Playground", 'Ocean Playground _', 'Playground _ _'])
+    n_gram_mottos.append(
+        ['Land', 'of', 'Living', 'Skies', '_ Land', 'Land of', 'of Living', 'Living Skies', 'Skies _', '_ _ Land',
+         '_ Land of', 'Land of Living', 'of Living Skies', 'Living Skies _', 'Skies _ _'])
+
+    def gen(texts):
+        for line in texts:
+            yield (np.array(line.split(" "), dtype='S'),)
+
+    dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
+    dataset = dataset.map(input_columns=["text"], operations=nlp.Ngram([1, 2, 3], ("_", 2), ("_", 2), " "))
+
+    i = 0
+    for data in dataset.create_dict_iterator():
+        assert [d.decode("utf8") for d in data["text"]] == n_gram_mottos[i]
+        i += 1
+
+
+def test_simple_ngram():
+    """ test simple gram with only one n value"""
+    plates_mottos = ["Friendly Manitoba", "Yours to Discover", "Land of Living Skies",
+                     "Birthplace of the Confederation"]
+    n_gram_mottos = [[]]
+    n_gram_mottos.append(["Yours to Discover"])
+    n_gram_mottos.append(['Land of Living', 'of Living Skies'])
+    n_gram_mottos.append(['Birthplace of the', 'of the Confederation'])
+
+    def gen(texts):
+        for line in texts:
+            yield (np.array(line.split(" "), dtype='S'),)
+
+    dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
+    dataset = dataset.map(input_columns=["text"], operations=nlp.Ngram(3, separator=None))
+
+    i = 0
+    for data in dataset.create_dict_iterator():
+        assert [d.decode("utf8") for d in data["text"]] == n_gram_mottos[i], i
+        i += 1
+
+
+def test_corner_cases():
+    """ testing various corner cases and exceptions"""
+
+    def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None):
+        def gen(text):
+            yield (np.array(text.split(" "), dtype='S'),)
+
+        dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
+        dataset = dataset.map(input_columns=["text"], operations=nlp.Ngram(n, l_pad, r_pad, separator=sep))
+        for data in dataset.create_dict_iterator():
+            assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
+
+    # test empty separator
+    test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="")
+    # test separator with longer length
+    test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^")
+    # test left pad != right pad
+    test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1))
+    # test invalid n
+    try:
+        test_config("Yours to Discover", "", [0, [1]])
+    except Exception as e:
+        assert "ngram needs to be a positive number" in str(e)
+    # test empty n
+    try:
+        test_config("Yours to Discover", "", [])
+    except Exception as e:
+        assert "n needs to be a non-empty list" in str(e)
+    # test invalid pad
+    try:
+        test_config("Yours to Discover", "", [1], ("str", -1))
+    except Exception as e:
+        assert "padding width need to be positive numbers" in str(e)
+    # test invalid pad
+    try:
+        test_config("Yours to Discover", "", [1], ("str", "rts"))
+    except Exception as e:
+        assert "pad needs to be a tuple of (str, int)" in str(e)
+
+
+if __name__ == '__main__':
+    test_multiple_ngrams()
+    test_simple_ngram()
+    test_corner_cases()