!5791 Add data_type parameter to Lookup API
Merge pull request !5791 from ZiruiWu/lookup_add_type
Commit: 8a20b5d784
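At the user level, this change lets the Lookup text transform emit word ids in a caller-chosen type instead of always int32. A minimal Python sketch of the intended usage, modeled on the new test added below (the three-word vocab and the input words are only illustrative):

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.common.dtype as mstype

def gen(texts):
    # yield one word per row as a string scalar
    for word in texts.split(" "):
        yield (np.array(word, dtype='S'),)

vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
data = ds.GeneratorDataset(gen("w1 w2 w3"), column_names=["text"])
# data_type selects the dtype of the id tensor produced by the lookup (default is mstype.int32)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "<unk>", mstype.int64))
for row in data.create_dict_iterator(num_epochs=1):
    print(row["text"].dtype)  # expected: int64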
@@ -121,12 +121,13 @@ PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) {

 PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
                   (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp")
-                    .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
+                    .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word,
+                                     const DataType &data_type) {
                       if (vocab == nullptr) {
                         THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
                       }
                       if (py_word.is_none()) {
-                        return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
+                        return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists, data_type);
                       }
                       std::string word = py::reinterpret_borrow<py::str>(py_word);
                       WordIdType default_id = vocab->Lookup(word);
@@ -134,7 +135,7 @@ PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) {
                         THROW_IF_ERROR(Status(StatusCode::kUnexpectedError,
                                               "default unknown token: " + word + " doesn't exist in vocab."));
                       }
-                      return std::make_shared<LookupOp>(vocab, default_id);
+                      return std::make_shared<LookupOp>(vocab, default_id, data_type);
                     }));
                 }));

@@ -22,8 +22,9 @@ namespace dataset {
 namespace api {
 namespace text {

-std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token) {
-  auto op = std::make_shared<LookupOperation>(vocab, unknown_token);
+std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                                        const DataType &data_type) {
+  auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);

   if (!op->ValidateParams()) {
     return nullptr;
@@ -32,8 +33,9 @@ std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, con
 }

 // LookupOperation
-LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token)
-    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {}
+LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                                 const DataType &data_type)
+    : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}

 bool LookupOperation::ValidateParams() {
   if (vocab_ == nullptr) {
@@ -54,7 +56,7 @@ bool LookupOperation::ValidateParams() {
 }

 std::shared_ptr<TensorOp> LookupOperation::Build() {
-  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_);
+  std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, data_type_);
   return tensor_op;
 }

@@ -20,9 +20,11 @@
 #include <vector>
 #include <memory>
 #include <string>

 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/include/transforms.h"
 #include "minddata/dataset/text/vocab.h"
+#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"

 namespace mindspore {
 namespace dataset {
@@ -37,15 +39,18 @@ class LookupOperation;
 /// \brief Lookup operator that looks up a word to an id.
 /// \param[in] vocab a Vocab object.
 /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
-///   If unknown_token is oov, runtime error will be thrown
+///   If unknown_token is oov, runtime error will be thrown.
+/// \param[in] DataType type of the tensor after lookup, typically int32.
 /// \return Shared pointer to the current TensorOperation.
-std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token);
+std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                                        const mindspore::dataset::DataType &data_type = DataType("int32"));

 /* ####################################### Derived TensorOperation classes ################################# */

 class LookupOperation : public TensorOperation {
  public:
-  explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token);
+  explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+                           const DataType &data_type);

   ~LookupOperation() = default;

@@ -57,6 +62,7 @@ class LookupOperation : public TensorOperation {
   std::shared_ptr<Vocab> vocab_;
   std::string unknown_token_;
   int32_t default_id_;
+  DataType data_type_;
 };
 } // namespace text
 } // namespace api

@@ -13,15 +13,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "minddata/dataset/text/kernels/lookup_op.h"

 #include <string>

+#include "minddata/dataset/kernels/data/data_utils.h"
+#include "minddata/dataset/text/kernels/lookup_op.h"

 namespace mindspore {
 namespace dataset {

-LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
-    : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}
+LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type)
+    : vocab_(vocab), default_id_(default_id), type_(data_type) {}

 Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
@@ -37,6 +38,14 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
       "Lookup Error: token: " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified.");
   }
   RETURN_IF_NOT_OK(Tensor::CreateFromVector(word_ids, input->shape(), output));

+  // type cast to user's requirements if what user wants isn't int32_t
+  if ((*output)->type() != type_) {
+    std::shared_ptr<Tensor> cast_to;
+    RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_));
+    *output = cast_to;
+  }
+
   return Status::OK();
 }
 Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {

@@ -18,9 +18,9 @@
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_

 #include <memory>
-#include <vector>
-#include <utility>
 #include <string>
+#include <utility>
+#include <vector>

 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
@@ -31,26 +31,27 @@ namespace mindspore {
 namespace dataset {
 class LookupOp : public TensorOp {
  public:
-  // constructor for lookup, takes in a vocab object
-  // @param std::shared_ptr<Vocab> vocab -
-  // @param WordIdType default_id, id to lookup if a word is not in vocab
-  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);
+  /// \brief constructor for lookup, takes in a vocab object.
+  /// \param[in] std::shared_ptr<Vocab> vocab - vocab used for lookup.
+  /// \param[in] WordIdType default_id, id to lookup if a word is not in vocab.
+  /// \param[in] DataType type of the tensor after lookup, mostly int32.
+  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type);

   ~LookupOp() = default;

-  // perform actual lookup on each tensor
-  // @param const std::shared_ptr<Tensor> &input
-  // @param std::shared_ptr<Tensor> *output
-  // @return error code
+  /// \brief perform actual lookup on each tensor.
+  /// \param[in] const std::shared_ptr<Tensor> &input
+  /// \param[in] std::shared_ptr<Tensor> *output
+  /// \return[out] error code.
   Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

-  // print method
-  // @param std::ostream out
+  /// \brief print method.
+  /// \param[in] std::ostream out
   void Print(std::ostream &out) const override;

-  // @param std::vector<DataType> &inputs -
-  // @param std::vector<DataType> &outputs -
-  // @return error code
+  /// \param[in] std::vector<DataType> &inputs -
+  /// \param[in] std::vector<DataType> &outputs -
+  /// \return[out] error code.
   Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

   std::string Name() const override { return kLookupOp; }

@@ -49,6 +49,7 @@ import platform
 import numpy as np

 import mindspore._c_dataengine as cde
+import mindspore.common.dtype as mstype

 from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType
 from .validators import check_lookup, check_jieba_add_dict, \
@@ -66,11 +67,12 @@ class Lookup(cde.LookupOp):
         vocab(Vocab): a Vocab object.
         unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov).
             If unknown_token is oov, runtime error will be thrown (default=None).
+        data_type (mindspore.dtype, optional): mindspore.dtype lookup maps string to (default=mstype.int32)
     """

     @check_lookup
-    def __init__(self, vocab, unknown_token=None):
-        super().__init__(vocab, unknown_token)
+    def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
+        super().__init__(vocab, unknown_token, mstype_to_detype(data_type))


 class SlidingWindow(cde.SlidingWindowOp):
@@ -103,7 +105,6 @@ class SlidingWindow(cde.SlidingWindowOp):
         super().__init__(width, axis)


-
 class Ngram(cde.NgramOp):
     """
     TensorOp to generate n-gram from a 1-D string Tensor.

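The new argument is optional at the Python layer, so existing pipelines keep producing int32 ids unchanged; only callers that pass a different mindspore dtype see a different output type. A short sketch of the two call forms (the vocabulary contents are illustrative):

import mindspore.dataset.text as text
import mindspore.common.dtype as mstype

vocab = text.Vocab.from_list(["home", "is", "behind"], special_tokens=["<unk>"])
lookup_default = text.Lookup(vocab, "<unk>")               # ids come out as int32, same as before this change
lookup_int64 = text.Lookup(vocab, "<unk>", mstype.int64)   # ids come out as int64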
@@ -44,12 +44,13 @@ def check_lookup(method):

     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs)
+        [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)

         if unknown_token is not None:
             type_check(unknown_token, (str,), "unknown_token")

         type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
+        type_check(data_type, (typing.Type,), "data_type")

         return method(self, *args, **kwargs)

@@ -327,6 +328,7 @@ def check_from_dataset(method):

     return new_method


 def check_slidingwindow(method):
     """A wrapper that wraps a parameter checker to the original function(sliding window operation)."""

@@ -339,6 +341,7 @@ def check_slidingwindow(method):

     return new_method


 def check_ngram(method):
     """A wrapper that wraps a parameter checker to the original function."""

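With the extra type_check, a data_type that is not a mindspore dtype is rejected in the Python validator before it ever reaches the C++ layer. A small sketch of the expected failure mode, following the new test assertion further below (the printed message is paraphrased from that assertion):

import mindspore.dataset.text as text

vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
try:
    # "tldr" is a plain string, not a mindspore dtype, so check_lookup raises
    text.Lookup(vocab, "<unk>", "tldr")
except TypeError as err:
    print(err)  # e.g. tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)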
@@ -26,9 +26,10 @@
 #include "minddata/dataset/include/text.h"

 using namespace mindspore::dataset::api;
+using mindspore::dataset::DataType;
 using mindspore::dataset::ShuffleMode;
-using mindspore::dataset::Tensor;
 using mindspore::dataset::Status;
+using mindspore::dataset::Tensor;
 using mindspore::dataset::Vocab;

 class MindDataTestPipeline : public UT::DatasetOpTesting {
@@ -50,7 +51,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) {
   EXPECT_EQ(s, Status::OK());

   // Create Lookup operation on ds
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
   EXPECT_NE(lookup, nullptr);

   // Create Map operation on ds
@@ -94,7 +95,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) {

   // Create lookup op for ds
   // Expected failure: "<unk>" is not a word of vocab
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
   EXPECT_EQ(lookup, nullptr);
 }

@@ -105,7 +106,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) {

   // Create lookup op
   // Expected failure: vocab is null
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32"));
   EXPECT_EQ(lookup, nullptr);
 }

@@ -126,7 +127,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) {

   // Create Lookup operation on ds
   // Expected failure: "" is not a word of vocab
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32"));
   EXPECT_EQ(lookup, nullptr);
 }

@@ -148,7 +149,7 @@ TEST_F(MindDataTestPipeline, TestVocabFromDataset) {
   EXPECT_EQ(home_index, 4);

   // Create Lookup operation on ds
-  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>");
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32"));
   EXPECT_NE(lookup, nullptr);

   // Create Map operation on ds
@@ -212,12 +213,15 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetDefault) {

   uint64_t i = 0;
   std::vector<int32_t> expected = {2, 3, 1, 4, 5, 0};
+  std::vector<int64_t> not_expected = {2, 3, 1, 4, 5, 0};
   while (row.size() != 0) {
     auto ind = row["text"];
     MS_LOG(INFO) << ind->shape() << " " << *ind;
-    std::shared_ptr<Tensor> expected_item;
+    std::shared_ptr<Tensor> expected_item, not_expected_item;
     Tensor::CreateScalar(expected[i], &expected_item);
+    Tensor::CreateScalar(not_expected[i], &not_expected_item);
     EXPECT_EQ(*ind, *expected_item);
+    EXPECT_NE(*ind, *not_expected_item);
     iter->GetNextRow(&row);
     i++;
   }
@@ -233,8 +237,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail1) {

   // Create vocab from dataset
   // Expected failure: top_k can not be negative
-  std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()},
-                                                -2, {"<pad>", "<unk>"}, true);
+  std::shared_ptr<Vocab> vocab =
+    ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, -2, {"<pad>", "<unk>"}, true);
   EXPECT_EQ(vocab, nullptr);
 }

@@ -247,9 +251,9 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail2) {
   EXPECT_NE(ds, nullptr);

   // Create vocab from dataset
-  // Expected failure: requency_range [a,b] should be 0 <= a <= b
-  std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {4, 1},
-                                                std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true);
+  // Expected failure: frequency_range [a,b] should be 0 <= a <= b
+  std::shared_ptr<Vocab> vocab =
+    ds->BuildVocab({"text"}, {4, 1}, std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true);
   EXPECT_EQ(vocab, nullptr);
 }

@@ -266,3 +270,52 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) {
   std::shared_ptr<Vocab> vocab = ds->BuildVocab({"ColumnNotExist"});
   EXPECT_EQ(vocab, nullptr);
 }
+
+TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64.";
+
+  // Create a TextFile dataset
+  std::string data_file = datasets_root_path_ + "/testVocab/words.txt";
+  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
+  EXPECT_NE(ds, nullptr);
+
+  // Create vocab from dataset
+  std::shared_ptr<Vocab> vocab = ds->BuildVocab();
+  EXPECT_NE(vocab, nullptr);
+
+  // Check if vocab has words or not
+  int32_t home_index = vocab->Lookup("home");
+  EXPECT_EQ(home_index, 2);
+
+  // Create Lookup operation on ds
+  std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "home", DataType("int64"));
+  EXPECT_NE(lookup, nullptr);
+
+  // Create Map operation on ds
+  ds = ds->Map({lookup});
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+
+  uint64_t i = 0;
+  std::vector<int64_t> expected = {2, 3, 1, 4, 5, 0};
+  std::vector<int8_t> not_expected = {2, 3, 1, 4, 5, 0};
+  while (row.size() != 0) {
+    auto ind = row["text"];
+    MS_LOG(INFO) << ind->shape() << " " << *ind;
+    std::shared_ptr<Tensor> expected_item, not_expected_item;
+    Tensor::CreateScalar(expected[i], &expected_item);
+    Tensor::CreateScalar(not_expected[i], &not_expected_item);
+    EXPECT_EQ(*ind, *expected_item);
+    EXPECT_NE(*ind, *not_expected_item);
+    iter->GetNextRow(&row);
+    i++;
+  }
+}

@@ -17,6 +17,7 @@ import numpy as np

 import mindspore.dataset as ds
 import mindspore.dataset.text as text
+import mindspore.common.dtype as mstype

 # this file contains "home is behind the world head" each word is 1 line
 DATA_FILE = "../data/dataset/testVocab/words.txt"
@@ -137,6 +138,36 @@ def test_from_file():
     assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True)
     assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True)


+def test_lookup_cast_type():
+    def gen(texts):
+        for word in texts.split(" "):
+            yield (np.array(word, dtype='S'),)
+
+    def test_config(lookup_str, data_type=None):
+        try:
+            vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
+            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
+            # if data_type is None, test the default value of data_type
+            op = text.Lookup(vocab, "<unk>") if data_type is None else text.Lookup(vocab, "<unk>", data_type)
+            data = data.map(input_columns=["text"], operations=op)
+            res = []
+            for d in data.create_dict_iterator(num_epochs=1):
+                res.append(d["text"])
+            return res[0].dtype
+        except (ValueError, RuntimeError, TypeError) as e:
+            return str(e)
+
+    # test result is correct
+    assert test_config("w1", mstype.int8) == np.dtype("int8")
+    assert test_config("w2", mstype.int32) == np.dtype("int32")
+    assert test_config("w3", mstype.int64) == np.dtype("int64")
+    assert test_config("unk", mstype.float32) != np.dtype("int32")
+    assert test_config("unk") == np.dtype("int32")
+    # test exception, data_type isn't the correct type
+    assert "tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)" in test_config("unk", "tldr")
+
+
 if __name__ == '__main__':
     test_from_dict_exception()
     test_from_list_tutorial()
@@ -144,3 +175,4 @@ if __name__ == '__main__':
     test_from_dict_tutorial()
     test_from_list()
     test_from_file()
+    test_lookup_cast_type()