forked from mindspore-Ecosystem/mindspore
!11893 LookupOp: default value of unknown_token support NoneType
From: @luoyang42 Reviewed-by: @liucunwei,@pandoublefeng Signed-off-by: @liucunwei
This commit is contained in:
commit
0dcd94d717
|
@ -154,8 +154,8 @@ PYBIND_REGISTER(
|
|||
PYBIND_REGISTER(LookupOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<text::LookupOperation, TensorOperation, std::shared_ptr<text::LookupOperation>>(
|
||||
*m, "LookupOperation")
|
||||
.def(py::init([](const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
|
||||
const std::string &data_type) {
|
||||
.def(py::init([](const std::shared_ptr<Vocab> &vocab,
|
||||
const std::optional<std::string> &unknown_token, const std::string &data_type) {
|
||||
auto lookup = std::make_shared<text::LookupOperation>(vocab, unknown_token, data_type);
|
||||
THROW_IF_ERROR(lookup->ValidateParams());
|
||||
return lookup;
|
||||
|
|
|
@ -87,8 +87,8 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p
|
|||
return op->ValidateParams() ? op : nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
|
||||
const std::string &data_type) {
|
||||
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab,
|
||||
const std::optional<std::string> &unknown_token, const std::string &data_type) {
|
||||
auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
|
||||
|
||||
return op->ValidateParams() ? op : nullptr;
|
||||
|
@ -340,7 +340,7 @@ Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) {
|
|||
}
|
||||
|
||||
// LookupOperation
|
||||
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
|
||||
LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
|
||||
const std::string &data_type)
|
||||
: vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
|
||||
|
||||
|
@ -352,10 +352,10 @@ Status LookupOperation::ValidateParams() {
|
|||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (!unknown_token_.empty()) {
|
||||
default_id_ = vocab_->Lookup(unknown_token_);
|
||||
if (unknown_token_ != std::nullopt) {
|
||||
default_id_ = vocab_->Lookup(*unknown_token_);
|
||||
if (default_id_ == Vocab::kNoTokenExists) {
|
||||
std::string err_msg = "Lookup: \"" + unknown_token_ + "\" doesn't exist in vocab.";
|
||||
std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -143,11 +144,13 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p
|
|||
/// \brief Lookup operator that looks up a word to an id.
|
||||
/// \param[in] vocab a Vocab object.
|
||||
/// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
|
||||
/// If unknown_token is oov, runtime error will be thrown.
|
||||
/// If unknown_token is oov, a runtime error will be thrown. If unknown_token is {}, no unknown_token is
|
||||
/// specified, so looking up a word that is out of Vocabulary raises an error (default={}).
|
||||
/// \param[in] data_type type of the tensor after lookup, typically int32.
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
|
||||
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
|
||||
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab,
|
||||
const std::optional<std::string> &unknown_token = {},
|
||||
const std::string &data_type = "int32");
|
||||
|
||||
/// \brief TensorOp to generate n-gram from a 1-D string Tensor.
|
||||
|
@ -343,7 +346,7 @@ class JiebaTokenizerOperation : public TensorOperation {
|
|||
|
||||
class LookupOperation : public TensorOperation {
|
||||
public:
|
||||
explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
|
||||
explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
|
||||
const std::string &data_type);
|
||||
|
||||
~LookupOperation();
|
||||
|
@ -356,7 +359,7 @@ class LookupOperation : public TensorOperation {
|
|||
|
||||
private:
|
||||
std::shared_ptr<Vocab> vocab_;
|
||||
std::string unknown_token_;
|
||||
std::optional<std::string> unknown_token_;
|
||||
int32_t default_id_;
|
||||
std::string data_type_;
|
||||
};
|
||||
|
|
|
@ -295,7 +295,7 @@ class Lookup(TextTensorOperation):
|
|||
@check_lookup
|
||||
def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
|
||||
self.vocab = vocab
|
||||
self.unknown_token = replace_none(unknown_token, '')
|
||||
self.unknown_token = unknown_token
|
||||
self.data_type = data_type
|
||||
|
||||
def parse(self):
|
||||
|
|
|
@ -119,6 +119,31 @@ def test_from_list():
|
|||
assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123)
|
||||
|
||||
|
||||
def test_from_list_lookup_empty_string():
    """
    Check how Lookup treats an empty-string unknown_token versus no unknown_token at all.

    Case 1: "" is registered as a special token in the vocab, so passing unknown_token=""
    is a valid lookup and the dataset is mapped to the expected ids.
    Case 2: unknown_token is left as None, so there is no fallback id; iterating the
    dataset must raise a RuntimeError for the out-of-vocabulary word "is".
    """
    # "" is a valid word in vocab, which can be looked up by LookupOp
    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True)
    lookup = text.Lookup(vocab, "")
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    ind = 0
    res = [2, 1, 4, 5, 6, 7]
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert d["text"] == res[ind], ind
        ind += 1

    # unknown_token of LookUp is None, it will convert to std::nullopt in C++,
    # so it has nothing to do with "" in vocab and C++ will skip looking up unknown_token
    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True)
    lookup = text.Lookup(vocab)
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(operations=lookup, input_columns=["text"])
    # Track whether the expected failure actually happened; a bare try/except would let the
    # test pass silently if the pipeline stopped raising for out-of-vocabulary words.
    error_raised = False
    try:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    except RuntimeError as e:
        error_raised = True
        assert "token: \"is\" doesn't exist in vocab and no unknown token is specified" in str(e)
    assert error_raised, "Lookup without unknown_token should raise RuntimeError for out-of-vocabulary words"
|
||||
|
||||
|
||||
def test_from_file():
|
||||
def gen(texts):
|
||||
for word in texts.split(" "):
|
||||
|
|
Loading…
Reference in New Issue