!7866 [MD] C++ API support build_sentence_piece_vocab_node & sentence_piece_tokenizer

Merge pull request !7866 from luoyang/c-api-pyfunc
mindspore-ci-bot 2020-10-29 14:26:13 +08:00 committed by Gitee
commit 8f44074796
16 changed files with 549 additions and 27 deletions

File: minddata/dataset/api/datasets.cc

@@ -67,6 +67,7 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#endif
@@ -553,6 +554,35 @@ std::shared_ptr<BucketBatchByLengthNode> Dataset::BucketBatchByLength(
return ds;
}
// Function to create a SentencePieceVocab from dataset
std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(shared_from_this(), vocab, col_names, vocab_size,
character_coverage, model_type, params);
// Validate input params
if (!ds->ValidateParams()) {
return nullptr;
}
// Run tree here to start building vocab
std::shared_ptr<Iterator> iter = ds->CreateIterator();
if (iter == nullptr) {
MS_LOG(ERROR) << "Fail to run iterator in BuildSentencePieceVocab.";
return nullptr;
}
// Finish building vocab by triggering GetNextRow
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
if (!iter->GetNextRow(&row)) {
return nullptr;
}
return vocab;
}
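A minimal caller-side sketch of the new API (the corpus path is hypothetical; the call shape mirrors the unit test added at the end of this PR):
  std::shared_ptr<Dataset> ds = TextFile({"/path/to/corpus.txt"}, 0, ShuffleMode::kFalse);
  // Build a 5000-piece unigram vocab over all columns; the empty params map uses library defaults.
  std::shared_ptr<SentencePieceVocab> vocab =
      ds->BuildSentencePieceVocab({}, 5000, 0.9995, SentencePieceModel::kUnigram, {});
  if (vocab == nullptr) {
    // Parameter validation failed or the vocab-building tree could not run.
  }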
// Function to create a Vocab from dataset
std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,

File: minddata/dataset/api/text.cc

@@ -14,8 +14,11 @@
* limitations under the License.
*/
#include <unistd.h>
#include "minddata/dataset/include/text.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
@@ -31,10 +34,21 @@ std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, con
const DataType &data_type) {
auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
-  if (!op->ValidateParams()) {
-    return nullptr;
-  }
-  return op;
+  return op->ValidateParams() ? op : nullptr;
}
std::shared_ptr<SentencePieceTokenizerOperation> SentencePieceTokenizer(
const std::shared_ptr<SentencePieceVocab> &vocab, SPieceTokenizerOutType out_type) {
auto op = std::make_shared<SentencePieceTokenizerOperation>(vocab, out_type);
return op->ValidateParams() ? op : nullptr;
}
std::shared_ptr<SentencePieceTokenizerOperation> SentencePieceTokenizer(const std::string &vocab_path,
SPieceTokenizerOutType out_type) {
auto op = std::make_shared<SentencePieceTokenizerOperation>(vocab_path, out_type);
return op->ValidateParams() ? op : nullptr;
}
/* ####################################### Validator Functions ############################################ */
@@ -70,6 +84,51 @@ std::shared_ptr<TensorOp> LookupOperation::Build() {
return tensor_op;
}
// SentencePieceTokenizerOperation
SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab,
SPieceTokenizerOutType out_type)
: vocab_(vocab), vocab_path_(std::string()), load_type_(SPieceTokenizerLoadType::kModel), out_type_(out_type) {}
SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::string &vocab_path,
SPieceTokenizerOutType out_type)
: vocab_(nullptr), vocab_path_(vocab_path), load_type_(SPieceTokenizerLoadType::kFile), out_type_(out_type) {}
Status SentencePieceTokenizerOperation::ValidateParams() {
if (load_type_ == SPieceTokenizerLoadType::kModel) {
if (vocab_ == nullptr) {
std::string err_msg = "SentencePieceTokenizer: vocab object type is incorrect or null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
} else {
Path vocab_file(vocab_path_);
if (!vocab_file.Exists() || vocab_file.IsDirectory()) {
std::string err_msg = "SentencePieceTokenizer : vocab file: [" + vocab_path_ + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(vocab_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = "SentencePieceTokenizer : no access to specified dataset file: " + vocab_path_;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
std::shared_ptr<SentencePieceTokenizerOp> tensor_op;
if (load_type_ == SPieceTokenizerLoadType::kModel) {
tensor_op = std::make_shared<SentencePieceTokenizerOp>(vocab_, load_type_, out_type_);
} else {
Path vocab_file(vocab_path_);
std::string model_path = vocab_file.ParentPath();
std::string model_filename = vocab_file.Basename();
tensor_op = std::make_shared<SentencePieceTokenizerOp>(model_path, model_filename, load_type_, out_type_);
}
return tensor_op;
}
} // namespace text
} // namespace api
} // namespace dataset
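For the kFile load type, Build() splits the vocab path into a directory and a filename before constructing the op; a quick sketch of that split (the path value is hypothetical):
  Path vocab_file("/data/test_sentencepiece/m.model");
  std::string model_path = vocab_file.ParentPath();    // "/data/test_sentencepiece"
  std::string model_filename = vocab_file.Basename();  // "m.model"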

File: minddata/dataset/api/vision.cc

@@ -724,7 +724,7 @@ Status NormalizeOperation::ValidateParams() {
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (mean_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(mean_[i], 0.0f)) {
std::string err_msg = "Normalize: mean vector has incorrect value: " + std::to_string(std_[i]);
std::string err_msg = "Normalize: mean vector has incorrect value: " + std::to_string(mean_[i]);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

File: minddata/dataset/engine/consumers/tree_consumer.cc

@@ -119,6 +119,11 @@ Status SaveToDisk::ValidateParams() {
MS_LOG(ERROR) << err;
RETURN_STATUS_SYNTAX_ERROR(err);
}
if (access(dir.ParentPath().c_str(), R_OK) == -1) {
std::string err_msg = "CreateSaver failed, no access to specified dataset path: " + dataset_path_;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (num_files_ <= 0 || num_files_ > 1000) {
std::string err = "CreateSaver failed, num_files must between 1 and 1000, but got " + std::to_string(num_files_);
MS_LOG(ERROR) << err;
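This is the same POSIX access(2) readability check the tokenizer applies to its vocab file; the pattern reduces to a helper along these lines (IsReadable is a hypothetical name, not part of this PR):
  #include <unistd.h>
  #include <string>
  // True when the calling process has read permission on path.
  bool IsReadable(const std::string &path) { return access(path.c_str(), R_OK) == 0; }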

File: minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.cc

@@ -76,7 +76,7 @@ Status BuildSentencePieceVocabOp::SentenceThread() {
} else {
auto itr = column_name_id_map_.find(col_names_[0]);
CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(),
"Invalid parameter, column name: " + col_names_[0] + "does not exist.");
"Invalid parameter, column name: " + col_names_[0] + " does not exist.");
col_id_ = itr->second;
}
std::unique_ptr<DatasetSentenceIterator> sentence_iter = std::make_unique<DatasetSentenceIterator>(this);

File: minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc

@@ -48,6 +48,17 @@ Status WeightedRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0,
"Invalid parameter, samples_per_buffer must be greater than 0, but got " +
std::to_string(samples_per_buffer_) + ".\n");
CHECK_FAIL_RETURN_UNEXPECTED(weights_.size() != 0, "Invalid parameter, weights size must not be 0.\n");
size_t zero_elem = 0;
for (auto &elem : weights_) {
CHECK_FAIL_RETURN_UNEXPECTED(elem >= 0.0, "Invalid parameter, weights must not contain negative numbers, but got " +
std::to_string(elem) + ".\n");
if (elem == 0.0) zero_elem++;
}
CHECK_FAIL_RETURN_UNEXPECTED(zero_elem != weights_.size(),
"Invalid parameter, elements of weights must not be all zero.\n");
if (weights_.size() > static_cast<size_t>(num_rows_)) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, size of sample weights must be less than or equal to num of data, "

File: minddata/dataset/engine/ir/datasetops/CMakeLists.txt

@@ -5,6 +5,7 @@ add_subdirectory(source)
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
batch_node.cc
bucket_batch_by_length_node.cc
build_sentence_piece_vocab_node.cc
build_vocab_node.cc
concat_node.cc
map_node.cc

File: minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc

@@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<Dataset> child,
std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params)
: vocab_(vocab),
col_names_(col_names),
vocab_size_(vocab_size),
character_coverage_(character_coverage),
model_type_(model_type),
params_(params) {
this->children.push_back(child);
}
// Function to build BuildSentenceVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildSentenceVocabNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<BuildSentencePieceVocabOp> build_sentence_piece_vocab_op;
build_sentence_piece_vocab_op = std::make_shared<BuildSentencePieceVocabOp>(
vocab_, col_names_, vocab_size_, character_coverage_, model_type_, params_, connector_que_size_);
node_ops.push_back(build_sentence_piece_vocab_op);
return node_ops;
}
Status BuildSentenceVocabNode::ValidateParams() {
if (vocab_ == nullptr) {
std::string err_msg = "BuildSentenceVocabNode: vocab is null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (vocab_size_ == 0) {
std::string err_msg =
"BuildSentenceVocabNode: vocab_size should be positive, but got: " + std::to_string(vocab_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (character_coverage_ < 0.98f || character_coverage_ > 1.0f) {
std::string err_msg = "BuildSentenceVocabNode: character_coverage should to be between 0.98 and 1.0, but got " +
std::to_string(character_coverage_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

File: minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_
#include <memory>
#include <string>
#include <utility>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
class BuildSentenceVocabNode : public Dataset {
public:
/// \brief Constructor
BuildSentenceVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params);
/// \brief Destructor
~BuildSentenceVocabNode() = default;
/// \brief A base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::shared_ptr<SentencePieceVocab> vocab_;
std::vector<std::string> col_names_;
uint32_t vocab_size_;
float character_coverage_;
SentencePieceModel model_type_;
std::unordered_map<std::string, std::string> params_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_

File: minddata/dataset/include/datasets.h

@@ -22,6 +22,7 @@
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -37,6 +38,7 @@
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/text/sentence_piece_vocab.h"
#include "minddata/dataset/text/vocab.h"
#endif
@@ -86,7 +88,6 @@ class VOCNode;
// Dataset Op classes (in alphabetical order)
#ifndef ENABLE_ANDROID
class BucketBatchByLengthNode;
class BuildVocabNode;
#endif
class ConcatNode;
class MapNode;
@@ -640,7 +641,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to transfer data through a device.
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
/// of data transmission per time is 256M.
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=True).
/// \return Returns true if no error encountered else false.
bool DeviceQueue(bool send_epoch_end = true);
@@ -663,7 +664,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to create a BatchNode
/// \notes Combines batch_size number of consecutive rows into batches
-/// \param[in] batch_size Path to the root directory that contains the dataset
+/// \param[in] batch_size The number of rows each batch is created with
/// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
/// batch. If true, and if there are less than batch_size rows
/// available to make the last batch, then those rows will
@@ -673,7 +674,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
#ifndef ENABLE_ANDROID
/// \brief Function to create a BucketBatchByLengthNode
-/// \notes Combines batch_size number of consecutive rows into batches
+/// \notes Bucket elements according to their lengths. Each bucket will be padded and batched when
+///    they are full.
/// \param[in] column_names Columns passed to element_length_function
/// \param[in] bucket_boundaries A list consisting of the upper boundaries of the buckets.
/// Must be strictly increasing. If there are n boundaries, n+1 buckets are created: One bucket for
@@ -681,10 +683,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// 0<i<n, and one bucket for [bucket_boundaries[n-1], inf).
/// \param[in] bucket_batch_sizes A list consisting of the batch sizes for each bucket.
/// Must contain elements equal to the size of bucket_boundaries + 1.
-/// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow. The
-///    output
-///    must contain a single tensor containing a single int32_t. If no value is provided, then size of column_names
-///    must be 1, and the size of the first dimension of that column will be taken as the length (default=nullptr)
+/// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow.
+///    The output must contain a single tensor containing a single int32_t. If no value is provided,
+///    then size of column_names must be 1, and the size of the first dimension of that column will be taken
+///    as the length (default=nullptr)
/// \param[in] pad_info Represents how to batch each column. The key corresponds to the column name, the value must
/// be a tuple of 2 elements. The first element corresponds to the shape to pad to, and the second element
/// corresponds to the value to pad with. If a column is not specified, then that column will be padded to the
@@ -692,8 +694,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// padded to the longest in the current batch, unless if pad_to_bucket_boundary is true. If no padding is
/// wanted, set pad_info to None (default=empty dictionary).
/// \param[in] pad_to_bucket_boundary If true, will pad each unspecified dimension in pad_info to the
-///    bucket_boundary
-///    minus 1. If there are any elements that fall into the last bucket, an error will occur (default=false).
+///    bucket_boundary minus 1. If there are any elements that fall into the last bucket,
+///    an error will occur (default=false).
/// \param[in] drop_remainder If true, will drop the last batch for each bucket if it is not a full batch
/// (default=false).
/// \return Shared pointer to the current BucketBatchByLengthNode
@@ -704,6 +706,20 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
/// \brief Function to create a SentencePieceVocab from source dataset
/// \notes Build a SentencePieceVocab from a dataset.
/// \param[in] col_names Column names to get words from. It can be a vector of column names
/// \param[in] vocab_size Vocabulary size. The type is uint32.
/// \param[in] character_coverage Percentage of characters covered by the model, must be between
/// 0.98 and 1.0. Good defaults are 0.9995 for languages with rich character sets like
/// Japanese or Chinese, and 1.0 for other languages with small character sets.
/// \param[in] model_type Model type. Choose from unigram (default), bpe, char, or word.
/// The input sentence must be pretokenized when using word type.
/// \param[in] params An unordered map containing additional option parameters for the sentencepiece library
/// \return Shared pointer to the SentencePieceVocab built from the dataset (nullptr if the parameters are invalid)
std::shared_ptr<SentencePieceVocab> BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params);
/// \brief Function to create a Vocab from source dataset
/// \notes Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
/// which contains top_k most frequent words (if top_k is specified)

File: minddata/dataset/include/text.h

@@ -21,11 +21,14 @@
#include <string>
#include <vector>
#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/text/vocab.h"
#include "minddata/dataset/util/status.h"
#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"
#include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"
#include "minddata/dataset/text/vocab.h"
namespace mindspore {
namespace dataset {
@@ -36,6 +39,7 @@ namespace text {
// Text Op classes (in alphabetical order)
class LookupOperation;
class SentencePieceTokenizerOperation;
/// \brief Lookup operator that looks up a word to an id.
/// \param[in] vocab a Vocab object.
@@ -46,6 +50,20 @@ class LookupOperation;
std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
const mindspore::dataset::DataType &data_type = DataType("int32"));
/// \brief Tokenize a scalar token or 1-D tokens into tokens with SentencePiece.
/// \param[in] vocab a SentencePieceVocab object.
/// \param[in] out_type The type of output.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<SentencePieceTokenizerOperation> SentencePieceTokenizer(
const std::shared_ptr<SentencePieceVocab> &vocab, mindspore::dataset::SPieceTokenizerOutType out_type);
/// \brief Tokenize a scalar token or 1-D tokens into tokens with SentencePiece.
/// \param[in] vocab_path The path to the SentencePiece vocab model file.
/// \param[in] out_type The type of output.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<SentencePieceTokenizerOperation> SentencePieceTokenizer(
const std::string &vocab_path, mindspore::dataset::SPieceTokenizerOutType out_type);
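Either overload plugs into a Map; a minimal sketch, assuming ds is an existing Dataset and vocab a previously built SentencePieceVocab (the model path is hypothetical):
  // From an in-memory vocab object:
  std::shared_ptr<TensorOperation> tok =
      text::SentencePieceTokenizer(vocab, SPieceTokenizerOutType::kString);
  // Or from a saved model file:
  // std::shared_ptr<TensorOperation> tok =
  //     text::SentencePieceTokenizer("/path/to/m.model", SPieceTokenizerOutType::kString);
  // Both factories return nullptr when validation fails, so check before use.
  ds = ds->Map({tok}, {"text"});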
/* ####################################### Derived TensorOperation classes ################################# */
class LookupOperation : public TensorOperation {
@@ -65,6 +83,25 @@ class LookupOperation : public TensorOperation {
int32_t default_id_;
DataType data_type_;
};
class SentencePieceTokenizerOperation : public TensorOperation {
public:
SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab, SPieceTokenizerOutType out_type);
SentencePieceTokenizerOperation(const std::string &vocab_path, SPieceTokenizerOutType out_type);
~SentencePieceTokenizerOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
private:
std::shared_ptr<SentencePieceVocab> vocab_;
std::string vocab_path_;
SPieceTokenizerLoadType load_type_;
SPieceTokenizerOutType out_type_;
};
} // namespace text
} // namespace api
} // namespace dataset

File: mindspore/dataset/engine/samplers.py

@@ -591,15 +591,6 @@ class WeightedRandomSampler(BuiltinSampler):
if not isinstance(weights, list):
weights = [weights]
-        if weights == []:
-            raise ValueError("weights size should not be 0")
-        if list(filter(lambda x: x < 0, weights)):
-            raise ValueError("weights should not contain negative numbers")
-        if list(filter(lambda x: x == 0, weights)) == weights:
-            raise ValueError("elements of weights should not be all zero")
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "

File: mindspore/lite/minddata/CMakeLists.txt

@@ -186,6 +186,7 @@ if (BUILD_MINDDATA STREQUAL "full")
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES
"${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_sentence_piece_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc"
)

File: tests/ut/cpp/dataset/CMakeLists.txt

@@ -112,12 +112,15 @@ SET(DE_UT_SRCS
c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_save.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_dataset_iterator_test.cc
c_api_text_sentence_piece_vocab_test.cc
c_api_text_vocab_test.cc
c_api_cache_test.cc
tensor_op_fusion_pass_test.cc

File: tests/ut/cpp/dataset/c_api_text_sentence_piece_vocab_test.cc

@@ -0,0 +1,224 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <string>
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/status.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/text.h"
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
using mindspore::dataset::ShuffleMode;
using mindspore::dataset::SentencePieceModel;
using mindspore::dataset::SentencePieceVocab;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestSentencePieceVocabSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSentencePieceVocabSuccess1 plus sentencepiece tokenizer.";
// Create a TextFile dataset
std::string vocab_file = datasets_root_path_ + "/test_sentencepiece/botchan.txt";
std::shared_ptr<Dataset> ds_vocab = TextFile({vocab_file}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds_vocab, nullptr);
// Create vocab from dataset
std::shared_ptr<SentencePieceVocab> vocab =
ds_vocab->BuildSentencePieceVocab({}, 5000, 0.9995, SentencePieceModel::kUnigram, {});
EXPECT_NE(vocab, nullptr);
// Create a TextFile dataset
std::string data_file = datasets_root_path_ + "/testTokenizerData/sentencepiece_tokenizer.txt";
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
// Create SentencePieceTokenizer operation from vocab object
std::shared_ptr<TensorOperation> sentencepiece_tokenizer =
text::SentencePieceTokenizer(vocab, mindspore::dataset::SPieceTokenizerOutType::kString);
EXPECT_NE(sentencepiece_tokenizer, nullptr);
// Create Map operation on ds
ds = ds->Map({sentencepiece_tokenizer}, {"text"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Expected result after tokenization
std::vector<std::string> expected = {"▁I", "▁sa", "w", "▁a", "▁girl", "▁with", "▁a", "▁te", "les", "co", "pe", "."};
uint64_t i = 0;
while (row.size() != 0) {
auto txt = row["text"];
MS_LOG(INFO) << *txt;
std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateFromVector(expected, &expected_tensor);
EXPECT_EQ(*txt, *expected_tensor);
iter->GetNextRow(&row);
i++;
}
EXPECT_EQ(i, 1);
}
TEST_F(MindDataTestPipeline, TestSentencePieceVocabSuccess2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSentencePieceVocabSuccess2 plus sentencepiece tokenizer.";
// Create a TextFile dataset
std::string vocab_file = datasets_root_path_ + "/test_sentencepiece/botchan.txt";
std::shared_ptr<Dataset> ds_vocab = TextFile({vocab_file}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds_vocab, nullptr);
// Create vocab from dataset
std::shared_ptr<SentencePieceVocab> vocab =
ds_vocab->BuildSentencePieceVocab({}, 5000, 0.9995, SentencePieceModel::kUnigram, {});
EXPECT_NE(vocab, nullptr);
// Save vocab model to local
vocab->SaveModel(&vocab, datasets_root_path_ + "/test_sentencepiece", "m.model");
// Create a TextFile dataset
std::string data_file = datasets_root_path_ + "/testTokenizerData/sentencepiece_tokenizer.txt";
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
// Create SentencePieceTokenizer operation from local vocab model
std::string vocab_model = datasets_root_path_ + "/test_sentencepiece/m.model";
std::shared_ptr<TensorOperation> sentencepiece_tokenizer =
text::SentencePieceTokenizer(vocab_model, mindspore::dataset::SPieceTokenizerOutType::kString);
EXPECT_NE(sentencepiece_tokenizer, nullptr);
// Create Map operation on ds
ds = ds->Map({sentencepiece_tokenizer}, {"text"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
// Expected result after tokenization
std::vector<std::string> expected = {"▁I", "▁sa", "w", "▁a", "▁girl", "▁with", "▁a", "▁te", "les", "co", "pe", "."};
uint64_t i = 0;
while (row.size() != 0) {
auto txt = row["text"];
MS_LOG(INFO) << *txt;
std::shared_ptr<Tensor> expected_tensor;
Tensor::CreateFromVector(expected, &expected_tensor);
EXPECT_EQ(*txt, *expected_tensor);
iter->GetNextRow(&row);
i++;
}
EXPECT_EQ(i, 1);
}
TEST_F(MindDataTestPipeline, TestSentencePieceVocabFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSentencePieceVocabFail1 with incorrect parameter.";
// Create a TextFile dataset
std::string vocab_file = datasets_root_path_ + "/test_sentencepiece/botchan.txt";
std::shared_ptr<Dataset> ds_vocab = TextFile({vocab_file}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds_vocab, nullptr);
// vocab_size must be greater than 0
std::shared_ptr<SentencePieceVocab> vocab1 =
ds_vocab->BuildSentencePieceVocab({}, 0, 0.9995, SentencePieceModel::kUnigram, {});
EXPECT_EQ(vocab1, nullptr);
// character_coverage must be between 0.98 and 1.0
std::shared_ptr<SentencePieceVocab> vocab2 =
ds_vocab->BuildSentencePieceVocab({}, 1, 0.979, SentencePieceModel::kUnigram, {});
EXPECT_EQ(vocab2, nullptr);
// character_coverage must be between 0.98 and 1.0
std::shared_ptr<SentencePieceVocab> vocab3 =
ds_vocab->BuildSentencePieceVocab({}, 1, 1.01, SentencePieceModel::kUnigram, {});
EXPECT_EQ(vocab3, nullptr);
// column name does not exist
std::shared_ptr<SentencePieceVocab> vocab4 =
ds_vocab->BuildSentencePieceVocab({"image"}, 2, 0.98, SentencePieceModel::kUnigram, {});
EXPECT_EQ(vocab4, nullptr);
}
TEST_F(MindDataTestPipeline, TestSentencePieceTokenizerFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSentencePieceTokenizerFail with incorrect parameter.";
// Create SentencePieceTokenizer operation from local vocab model
std::string vocab_model1 = "";
std::shared_ptr<TensorOperation> sentencepiece_tokenizer1 =
text::SentencePieceTokenizer(vocab_model1, mindspore::dataset::SPieceTokenizerOutType::kString);
EXPECT_EQ(sentencepiece_tokenizer1, nullptr);
// Create SentencePieceTokenizer operation from local vocab model
std::string vocab_model2 = "m.model";
std::shared_ptr<TensorOperation> sentencepiece_tokenizer2 =
text::SentencePieceTokenizer(vocab_model2, mindspore::dataset::SPieceTokenizerOutType::kString);
EXPECT_EQ(sentencepiece_tokenizer2, nullptr);
// Create SentencePieceTokenizer operation from vocab object
std::shared_ptr<SentencePieceVocab> vocab_model3 = nullptr;
std::shared_ptr<TensorOperation> sentencepiece_tokenizer3 =
text::SentencePieceTokenizer(vocab_model3, mindspore::dataset::SPieceTokenizerOutType::kString);
EXPECT_EQ(sentencepiece_tokenizer3, nullptr);
}
TEST_F(MindDataTestPipeline, TestSentencePieceTokenizerFail2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSentencePieceTokenizerFail with invalid SentencePieceVocab object.";
// Create a TextFile dataset
std::string data_file = datasets_root_path_ + "/testTokenizerData/sentencepiece_tokenizer.txt";
std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
// Create SentencePieceTokenizer operation from vocab object
std::shared_ptr<SentencePieceVocab> vocab_model4 = std::make_shared<SentencePieceVocab>();
std::shared_ptr<TensorOperation> sentencepiece_tokenizer4 =
text::SentencePieceTokenizer(vocab_model4, mindspore::dataset::SPieceTokenizerOutType::kString);
EXPECT_NE(sentencepiece_tokenizer4, nullptr);
// Create Map operation on ds
ds = ds->Map({sentencepiece_tokenizer4}, {"text"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
EXPECT_EQ(iter->GetNextRow(&row), false);
}