!17374 files_need_cleanup report fixes to master

From: @hfarahat
Reviewed-by: @pandoublefeng,@robingrosman
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2021-06-01 05:21:18 +08:00 committed by Gitee
commit 4338e4abac
44 changed files with 252 additions and 356 deletions

View File

@ -43,11 +43,10 @@ Status DeviceTensor::CreateEmpty(const TensorShape &shape, const DataType &type,
CHECK_FAIL_RETURN_UNEXPECTED(type.IsNumeric(), "Number of elements is not 0. The type should be numeric.");
int64_t byte_size = (*out)->SizeInBytes();
int64_t bytes = (*out)->SizeInBytes();
// Don't allocate if we have a tensor with no elements.
if (byte_size != 0) {
RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size));
if (bytes != 0) {
RETURN_IF_NOT_OK((*out)->AllocateBuffer(bytes));
}
return Status::OK();
}

View File

@ -21,7 +21,6 @@
#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
namespace mindspore::dataset {
PullBasedIteratorConsumer::PullBasedIteratorConsumer() { tree_adapter_lite_ = std::make_unique<TreeAdapterLite>(); }
Status PullBasedIteratorConsumer::Init(std::shared_ptr<DatasetNode> root) {

View File

@ -19,7 +19,6 @@
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
namespace mindspore::dataset {
Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
std::vector<TensorPtr> row;
{

View File

@ -44,18 +44,7 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
out << " [epochs: " << num_repeats_ << "]\n";
} else {
// Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << "\n Operator: " << eoe_ops_[i]->id();
}
} else {
out << " None.";
}
out << "\n\n";
RepeatOp::Print(out, show_all);
}
}

View File

@ -57,8 +57,7 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_
<< "\nLeaf Nodes in execution path:";
out << "\nCurrent count: " << repeat_count_ << "\nMax count: " << num_repeats_ << "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << "\n Operator: " << eoe_ops_[i]->id();

View File

@ -199,21 +199,12 @@ void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const
Status DistributedSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_devices_;
args["shard_id"] = device_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (auto child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

View File

@ -129,19 +129,10 @@ void PKSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
Status PKSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
args["sampler_name"] = "PKSampler";
args["num_val"] = samples_per_class_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (auto child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

View File

@ -128,19 +128,11 @@ void RandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
Status RandomSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (auto child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

View File

@ -183,6 +183,21 @@ Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
return Status::OK();
}
Status SamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_samples"] = num_samples_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (const auto &child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -156,7 +156,7 @@ class SamplerRT {
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
virtual Status to_json(nlohmann::json *out_json);
protected:
// Number of rows of data from the place this sampler is sampling from. If this sampler

View File

@ -133,18 +133,9 @@ void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
Status SequentialSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (auto child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

View File

@ -116,18 +116,10 @@ void SubsetSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
Status SubsetSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (auto child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

View File

@ -189,19 +189,10 @@ void WeightedRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) con
Status WeightedRandomSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (this->HasChildSampler()) {
std::vector<nlohmann::json> children_args;
for (auto child : child_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}

View File

@ -126,7 +126,7 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s
}
if (shard_id < 0 || shard_id >= num_shards) {
// num_shards;
// num_shards
std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
", num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
@ -150,8 +150,8 @@ Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared
Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
std::string init;
std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;

View File

@ -144,8 +144,7 @@ Status GeneratorNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s
return Status::OK();
} else {
int64_t sample_size;
int64_t num_rows;
num_rows = source_len_;
int64_t num_rows = source_len_;
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_ ? sampler_rt->CalculateNumSamples(num_rows) : num_rows;

View File

@ -18,7 +18,6 @@
#include "pybind11/pybind11.h"
namespace mindspore::dataset {
Status PythonRuntimeContext::Terminate() {
MS_LOG(INFO) << "Terminating a PythonRuntime";
if (tree_consumer_ != nullptr) {

View File

@ -18,7 +18,6 @@
#include <memory>
#include <utility>
namespace mindspore::dataset {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer);
}

View File

@ -548,7 +548,7 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_BOOL), output));
std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
std::unique_ptr<TypeCastOp> value_cast_op = std::make_unique<TypeCastOp>(input->type());
std::shared_ptr<Tensor> casted_value;
if (input->type().IsNumeric()) {
RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));

View File

@ -51,7 +51,7 @@ Status BoundingBox::ValidateBoundingBoxes(const TensorRow &image_and_bbox) {
"BoundingBox: bounding boxes should have to be two-dimensional matrix at least.");
}
uint32_t num_of_features = image_and_bbox[1]->shape()[1];
if (num_of_features < 4) {
if (num_of_features < kNumOfCols) {
return Status(StatusCode::kMDBoundingBoxInvalidShape, __LINE__, __FILE__,
"BoundingBox: bounding boxes should be have at least 4 features.");
}

View File

@ -25,7 +25,8 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
int32_t *lo) {
try {
int32_t n = std::accumulate(hist->begin(), hist->end(), 0);
int32_t cut = static_cast<int32_t>((low_p / 100.0) * n);
constexpr float kMaxPerc = 100.0;
int32_t cut = static_cast<int32_t>((low_p / kMaxPerc) * n);
for (int32_t lb = 0; lb < hist->size() + 1 && cut > 0; lb++) {
if (cut > (*hist)[lb]) {
cut -= (*hist)[lb];
@ -35,7 +36,7 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
cut = 0;
}
}
cut = static_cast<int32_t>((hi_p / 100.0) * n);
cut = static_cast<int32_t>((hi_p / kMaxPerc) * n);
for (int32_t ub = hist->size() - 1; ub >= 0 && cut > 0; ub--) {
if (cut > (*hist)[ub]) {
cut -= (*hist)[ub];
@ -52,9 +53,8 @@ Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p,
for (; (*hi) >= 0 && !(*hist)[*hi]; (*hi)--) {
}
} catch (const std::exception &e) {
const char *err_msg = e.what();
std::string err_message = "AutoContrast: ComputeUpperAndLowerPercentiles failed: ";
err_message += err_msg;
err_message += e.what();
RETURN_STATUS_UNEXPECTED(err_message);
}
return Status::OK();
@ -70,9 +70,8 @@ Status GenerateRealNumber(float_t a, float_t b, std::mt19937 *rnd, float_t *resu
std::uniform_real_distribution<float_t> distribution{a, b};
*result = distribution(*rnd);
} catch (const std::exception &e) {
const char *err_msg = e.what();
std::string err_message = "RandomAffine: GenerateRealNumber failed: ";
err_message += err_msg;
err_message += e.what();
RETURN_STATUS_UNEXPECTED(err_message);
}
return Status::OK();

View File

@ -15,6 +15,7 @@ add_library(text-kernels OBJECT
data_utils.cc
lookup_op.cc
jieba_tokenizer_op.cc
tokenizer_op.cc
unicode_char_tokenizer_op.cc
ngram_op.cc
sliding_window_op.cc

View File

@ -31,7 +31,6 @@ const bool BasicTokenizerOp::kDefLowerCase = false;
const bool BasicTokenizerOp::kDefKeepWhitespace = false;
const NormalizeForm BasicTokenizerOp::kDefNormalizationForm = NormalizeForm::kNone;
const bool BasicTokenizerOp::kDefPreserveUnusedToken = true;
const bool BasicTokenizerOp::kDefWithOffsets = false;
const char BasicTokenizerOp::kCommonPattern[] =
"[!-/]"
"|[:-@]"
@ -52,10 +51,10 @@ const std::unordered_set<std::string> BasicTokenizerOp::kUnusedWords{"[CLS]", "[
BasicTokenizerOp::BasicTokenizerOp(const bool &lower_case, const bool &keep_whitespace,
const NormalizeForm &normalization_form, const bool &preserve_unused_token,
const bool &with_offsets)
: lower_case_(lower_case),
: TokenizerOp(with_offsets),
lower_case_(lower_case),
keep_whitespace_(keep_whitespace),
preserve_unused_token_(preserve_unused_token),
with_offsets_(with_offsets),
case_fold_(std::make_unique<CaseFoldOp>()),
nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
normalization_form_(normalization_form),

View File

@ -25,18 +25,18 @@
#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
#include "minddata/dataset/text/kernels/regex_replace_op.h"
#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class BasicTokenizerOp : public TensorOp {
class BasicTokenizerOp : public TokenizerOp {
public:
static const bool kDefLowerCase;
static const bool kDefKeepWhitespace;
static const NormalizeForm kDefNormalizationForm;
static const bool kDefPreserveUnusedToken;
static const bool kDefWithOffsets;
explicit BasicTokenizerOp(const bool &lower_case = kDefLowerCase, const bool &keep_whitespace = kDefKeepWhitespace,
const NormalizeForm &normalization_form = kDefNormalizationForm,
@ -58,7 +58,6 @@ class BasicTokenizerOp : public TensorOp {
static const char kCommonPattern[];
static const char kUnusedPattern[];
static const std::unordered_set<std::string> kUnusedWords;
bool with_offsets_;
bool lower_case_;
bool keep_whitespace_;
NormalizeForm normalization_form_;

View File

@ -21,6 +21,8 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
#include "minddata/dataset/util/status.h"
@ -36,7 +38,7 @@ class BertTokenizerOp : public TensorOp {
const bool &keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
const NormalizeForm &normalization_form = BasicTokenizerOp::kDefNormalizationForm,
const bool &preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken,
const bool &with_offsets = WordpieceTokenizerOp::kDefWithOffsets)
const bool &with_offsets = TokenizerOp::kDefWithOffsets)
: wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets),
basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token, with_offsets) {}

View File

@ -23,31 +23,18 @@
namespace mindspore {
namespace dataset {
const bool JiebaTokenizerOp::kDefWithOffsets = false;
JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::string &dict_path, const JiebaMode &mode,
const bool &with_offsets)
: jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path), with_offsets_(with_offsets) {
: TokenizerOp(with_offsets), jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path) {
jieba_parser_ = std::make_unique<cppjieba::Jieba>(mp_dict_path_, hmm_model_path_, "");
}
Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "JiebaTokenizer: input only support one column data.");
RETURN_UNEXPECTED_IF_NULL(jieba_parser_);
if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("JiebaTokenizer: the input should be scalar with string datatype.");
}
std::string_view sentence_v;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&sentence_v, {}));
Status JiebaTokenizerOp::Tokenize(std::string_view sentence_v, std::vector<std::string> *words,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) {
std::string sentence{sentence_v};
std::vector<std::string> words;
std::vector<uint32_t> offsets_start, offsets_limit;
std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
if (sentence == "") {
words.push_back("");
words->push_back("");
} else {
std::vector<cppjieba::Word> tmp;
if (jieba_mode_ == JiebaMode::kMp) {
@ -62,21 +49,13 @@ Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
std::make_unique<cppjieba::MixSegment>(jieba_parser_->GetDictTrie(), jieba_parser_->GetHMMModel());
mix_seg->Cut(sentence, tmp, true);
}
GetStringsFromWords(tmp, words);
GetStringsFromWords(tmp, *words);
for (auto item : tmp) {
offsets_start.push_back(static_cast<uint32_t>(item.offset));
offsets_limit.push_back(static_cast<uint32_t>(item.offset + item.word.length()));
offsets_start->push_back(static_cast<uint32_t>(item.offset));
offsets_limit->push_back(static_cast<uint32_t>(item.offset + item.word.length()));
}
}
RETURN_IF_NOT_OK(Tensor::CreateFromVector(words, &token_tensor));
output->push_back(token_tensor);
if (with_offsets_) {
RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
output->push_back(offsets_start_tensor);
output->push_back(offsets_limit_tensor);
}
return Status::OK();
}

View File

@ -17,21 +17,21 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TEXT_JIEBA_OP_H_
#include <string>
#include <vector>
#include <memory>
#include "cppjieba/Jieba.hpp"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class JiebaTokenizerOp : public TensorOp {
class JiebaTokenizerOp : public TokenizerOp {
public:
// default constant for Jieba MPSegment algorithm.
static constexpr size_t MAX_WORD_LENGTH = 512;
// default const for set whether Jieba output offsets tensor.
static const bool kDefWithOffsets;
// Constructor for JiebaTokenizerOp.
// @param hmm_path HMM model file.
// @param mp_path MP model file.
@ -47,7 +47,8 @@ class JiebaTokenizerOp : public TensorOp {
out << Name() << ": " << jieba_mode_ << "hmm_model_path_ " << hmm_model_path_ << "mp_dict_path_" << mp_dict_path_;
}
Status Compute(const TensorRow &input, TensorRow *output) override;
Status Tokenize(std::string_view str, std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
std::vector<uint32_t> *offsets_limit) override;
// @word the word to be added to the JiebaTokenizer.
// @freq [Default 0] the frequency fo the word to be added.
@ -61,7 +62,6 @@ class JiebaTokenizerOp : public TensorOp {
std::string mp_dict_path_;
std::unique_ptr<cppjieba::Jieba> jieba_parser_;
JiebaMode jieba_mode_;
bool with_offsets_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {

View File

@ -22,8 +22,6 @@
namespace mindspore {
namespace dataset {
const bool RegexTokenizerOp::kDefWithOffsets = false;
Status RegexTokenizerOp::GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len,
std::string *out_utf8, icu::UnicodeString *out_unicode) const {
CHECK_FAIL_RETURN_UNEXPECTED((out_utf8 != nullptr || out_unicode != nullptr), "RegexTokenizer: get token failed.");
@ -109,29 +107,10 @@ Status RegexTokenizerOp::GetRegexTokens(const std::string &text, std::vector<std
return Status::OK();
}
Status RegexTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "RegexTokenizer: input should be one column data");
if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED(
"RegexTokenizer: the input shape should be scalar and "
"the input datatype should be string.");
}
std::string_view text;
std::vector<std::string> tokens;
std::vector<uint32_t> offsets_start;
std::vector<uint32_t> offsets_limit;
std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&text, {}));
RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens, &offsets_start, &offsets_limit));
RETURN_IF_NOT_OK(Tensor::CreateFromVector(std::move(tokens), &token_tensor));
output->push_back(token_tensor);
if (with_offsets_) {
RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
output->push_back(offsets_start_tensor);
output->push_back(offsets_limit_tensor);
}
Status RegexTokenizerOp::Tokenize(std::string_view str, std::vector<std::string> *splits,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) {
RETURN_IF_NOT_OK(GetRegexTokens(std::string(str.data(), str.size()), splits, offsets_start, offsets_limit));
return Status::OK();
}
} // namespace dataset

View File

@ -25,25 +25,25 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RegexTokenizerOp : public TensorOp {
class RegexTokenizerOp : public TokenizerOp {
public:
static const bool kDefWithOffsets;
RegexTokenizerOp(const std::string &delim_pattern, const std::string &keep_delim_pattern,
const bool &with_offsets = kDefWithOffsets)
: delim_pattern_(icu::UnicodeString::fromUTF8(delim_pattern)),
: TokenizerOp(with_offsets),
delim_pattern_(icu::UnicodeString::fromUTF8(delim_pattern)),
keep_delim_pattern_(icu::UnicodeString::fromUTF8(keep_delim_pattern)),
with_offsets_(with_offsets),
keep_delim_(!keep_delim_pattern.empty()) {}
~RegexTokenizerOp() override = default;
Status Compute(const TensorRow &input, TensorRow *output) override;
Status Tokenize(std::string_view str, std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
std::vector<uint32_t> *offsets_limit) override;
protected:
Status GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, std::string *out_utf8,
@ -56,7 +56,6 @@ class RegexTokenizerOp : public TensorOp {
private:
const icu::UnicodeString delim_pattern_;
const icu::UnicodeString keep_delim_pattern_;
bool with_offsets_;
const bool keep_delim_;
};
} // namespace dataset

View File

@ -25,6 +25,7 @@
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/text/sentence_piece_vocab.h"

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include <memory>
#include <string>
#include <string_view>
#include <vector>
namespace mindspore {
namespace dataset {
const bool TokenizerOp::kDefWithOffsets = false;
Status TokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, Name() + ": input should be one column data.");
if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED(Name() + ": the input shape should be scalar and the input datatype should be string.");
}
std::string_view str;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {}));
std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
std::vector<uint32_t> offsets_start, offsets_limit;
std::vector<std::string> splits;
RETURN_IF_NOT_OK(Tokenize(str, &splits, &offsets_start, &offsets_limit));
if (splits.empty()) {
splits.emplace_back("");
offsets_start.push_back(0);
offsets_limit.push_back(0);
}
RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
output->push_back(token_tensor);
if (with_offsets_) {
RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
output->push_back(offsets_start_tensor);
output->push_back(offsets_limit_tensor);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TOKENIZER_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TOKENIZER_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class TokenizerOp : public TensorOp {
public:
static const bool kDefWithOffsets;
explicit TokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {}
~TokenizerOp() override = default;
virtual Status Tokenize(std::string_view str, std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
std::vector<uint32_t> *offsets_limit) {
return Status::OK();
}
Status Compute(const TensorRow &input, TensorRow *output) override;
protected:
bool with_offsets_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TOKENIZER_OP_H_

View File

@ -45,7 +45,7 @@ Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output
}
std::shared_ptr<Tensor> outSeq1;
if (length1 != outLength1) {
std::unique_ptr<SliceOp> slice1(new SliceOp(Slice(outLength1 - length1)));
std::unique_ptr<SliceOp> slice1 = std::make_unique<SliceOp>(Slice(outLength1 - length1));
RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1));
} else {
outSeq1 = std::move(seq1);
@ -53,7 +53,7 @@ Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output
std::shared_ptr<Tensor> outSeq2;
if (length2 != outLength2) {
std::unique_ptr<SliceOp> slice2(new SliceOp(Slice(outLength2 - length2)));
std::unique_ptr<SliceOp> slice2 = std::make_unique<SliceOp>(Slice(outLength2 - length2));
RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2));
} else {
outSeq2 = std::move(seq2);

View File

@ -24,8 +24,6 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
namespace mindspore {
namespace dataset {

View File

@ -15,7 +15,9 @@
*/
#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "cppjieba/Unicode.hpp"
@ -26,32 +28,20 @@ using cppjieba::RuneStrArray;
namespace mindspore {
namespace dataset {
const bool UnicodeCharTokenizerOp::kDefWithOffsets = false;
Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "UnicodeCharTokenizer: input should be one column data.");
if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED(
"UnicodeCharTokenizer: "
"the input shape should be scalar and the input datatype should be string.");
}
std::string_view str;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {}));
Status UnicodeCharTokenizerOp::Tokenize(std::string_view str, std::vector<std::string> *splits,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) {
RuneStrArray runes;
if (!DecodeRunesInString(str.data(), str.size(), runes)) {
RETURN_STATUS_UNEXPECTED("UnicodeCharTokenizer: Decode utf8 string failed.");
}
std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
std::vector<std::string> splits(runes.size());
std::vector<uint32_t> offsets_start, offsets_limit;
std::vector<std::string> words(runes.size());
for (size_t i = 0; i < runes.size(); i++) {
offsets_start.push_back(runes[i].offset);
offsets_limit.push_back(runes[i].offset + runes[i].len);
splits[i] = str.substr(runes[i].offset, runes[i].len);
offsets_start->push_back(runes[i].offset);
offsets_limit->push_back(runes[i].offset + runes[i].len);
words[i] = str.substr(runes[i].offset, runes[i].len);
}
return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
*splits = std::move(words);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -16,30 +16,27 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/data_utils.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class UnicodeCharTokenizerOp : public TensorOp {
class UnicodeCharTokenizerOp : public TokenizerOp {
public:
static const bool kDefWithOffsets;
explicit UnicodeCharTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {}
explicit UnicodeCharTokenizerOp(const bool &with_offsets = kDefWithOffsets) : TokenizerOp(with_offsets) {}
~UnicodeCharTokenizerOp() override = default;
Status Compute(const TensorRow &input, TensorRow *output) override;
Status Tokenize(std::string_view str, std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
std::vector<uint32_t> *offsets_limit) override;
std::string Name() const override { return kUnicodeCharTokenizerOp; }
private:
bool with_offsets_;
};
} // namespace dataset

View File

@ -15,6 +15,7 @@
*/
#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
@ -31,30 +32,18 @@ namespace mindspore {
namespace dataset {
const bool UnicodeScriptTokenizerOp::kDefKeepWhitespace = false;
const bool UnicodeScriptTokenizerOp::kDefWithOffsets = false;
Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "UnicodeScriptTokenizer: input should be one column data.");
if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED(
"UnicodeScriptTokenizer: "
"the input shape should be scalar and the input datatype should be string.");
}
std::string_view str;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {}));
Status UnicodeScriptTokenizerOp::Tokenize(std::string_view str, std::vector<std::string> *splits,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) {
RuneStrArray runes;
if (!DecodeRunesInString(str.data(), str.size(), runes)) {
RETURN_STATUS_UNEXPECTED("UnicodeScriptTokenizer: Decode utf8 string failed.");
}
std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
UScriptCode last_script = USCRIPT_INVALID_CODE;
icu::ErrorCode status;
int start = 0;
int len = 0;
std::vector<std::string> splits;
std::vector<uint32_t> offsets_start, offsets_limit;
bool was_space = false;
for (size_t i = 0; i < runes.size(); i++) {
@ -71,10 +60,10 @@ Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *outp
if (len > 0 && (script != last_script || is_space != was_space)) {
// 3) If keep_whitespace_ is false, all the whitespace characters will be discard
if (keep_whitespace_ || !was_space) {
offsets_start.push_back(static_cast<uint32_t>(start));
offsets_limit.push_back(static_cast<uint32_t>(start + len));
offsets_start->push_back(static_cast<uint32_t>(start));
offsets_limit->push_back(static_cast<uint32_t>(start + len));
std::string temp(str.substr(start, len));
splits.emplace_back(std::move(temp));
splits->emplace_back(std::move(temp));
}
start = runes[i].offset;
len = runes[i].len;
@ -86,13 +75,13 @@ Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *outp
}
if (len > 0 && (keep_whitespace_ || !was_space)) {
offsets_start.push_back(static_cast<uint32_t>(start));
offsets_limit.push_back(static_cast<uint32_t>(start + len));
offsets_start->push_back(static_cast<uint32_t>(start));
offsets_limit->push_back(static_cast<uint32_t>(start + len));
std::string temp(str.substr(start, len));
splits.emplace_back(std::move(temp));
splits->emplace_back(std::move(temp));
}
// 4) If the input is empty scalar string, the output will be 1-D empty string.
return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -16,34 +16,34 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/data_utils.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class UnicodeScriptTokenizerOp : public TensorOp {
class UnicodeScriptTokenizerOp : public TokenizerOp {
public:
static const bool kDefKeepWhitespace;
static const bool kDefWithOffsets;
explicit UnicodeScriptTokenizerOp(const bool &keep_whitespace = kDefKeepWhitespace,
const bool &with_offsets = kDefWithOffsets)
: keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {}
: TokenizerOp(with_offsets), keep_whitespace_(keep_whitespace) {}
~UnicodeScriptTokenizerOp() override = default;
Status Compute(const TensorRow &input, TensorRow *output) override;
Status Tokenize(std::string_view str, std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
std::vector<uint32_t> *offsets_limit) override;
std::string Name() const override { return kUnicodeScriptTokenizerOp; }
private:
bool keep_whitespace_; // If or not keep whitespace tokens
bool with_offsets_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -14,7 +14,6 @@
* limitations under the License.
*/
#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
#include <memory>
#include <string_view>
#include <utility>
#include <vector>
@ -28,35 +27,22 @@ using cppjieba::RuneStrArray;
namespace mindspore {
namespace dataset {
const bool WhitespaceTokenizerOp::kDefWithOffsets = false;
Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "WhitespaceTokenizer: input should be one column data.");
if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED(
"WhitespaceTokenizer: the input shape should be scalar and the input datatype should be string.");
}
std::string_view str;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {}));
Status WhitespaceTokenizerOp::Tokenize(std::string_view str, std::vector<std::string> *splits,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) {
RuneStrArray runes;
if (!DecodeRunesInString(str.data(), str.size(), runes)) {
RETURN_STATUS_UNEXPECTED("WhitespaceTokenizer: Decode utf8 string failed.");
}
std::vector<uint32_t> offsets_start, offsets_limit;
std::vector<std::string> splits;
int start = 0;
int len = 0;
for (size_t i = 0; i < runes.size(); i++) {
if (u_isUWhiteSpace(runes[i].rune)) {
if (len > 0) {
offsets_start.push_back(static_cast<uint32_t>(start));
offsets_limit.push_back(static_cast<uint32_t>(start + len));
offsets_start->push_back(static_cast<uint32_t>(start));
offsets_limit->push_back(static_cast<uint32_t>(start + len));
std::string temp(str.substr(start, len));
splits.emplace_back(std::move(temp));
splits->emplace_back(std::move(temp));
len = 0;
}
} else {
@ -67,12 +53,17 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output)
}
}
if (len > 0) {
offsets_start.push_back(static_cast<uint32_t>(start));
offsets_limit.push_back(static_cast<uint32_t>(start + len));
offsets_start->push_back(static_cast<uint32_t>(start));
offsets_limit->push_back(static_cast<uint32_t>(start + len));
std::string temp(str.substr(start, len));
splits.emplace_back(std::move(temp));
splits->emplace_back(std::move(temp));
}
return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
if (splits->empty()) {
splits->emplace_back("");
offsets_start->push_back(0);
offsets_limit->push_back(0);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -16,30 +16,28 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/data_utils.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class WhitespaceTokenizerOp : public TensorOp {
class WhitespaceTokenizerOp : public TokenizerOp {
public:
static const bool kDefWithOffsets;
explicit WhitespaceTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {}
explicit WhitespaceTokenizerOp(const bool &with_offsets = kDefWithOffsets) : TokenizerOp(with_offsets) {}
~WhitespaceTokenizerOp() override = default;
Status Compute(const TensorRow &input, TensorRow *output) override;
Status Tokenize(std::string_view str, std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
std::vector<uint32_t> *offsets_limit) override;
std::string Name() const override { return kWhitespaceTokenizerOp; }
private:
bool with_offsets_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -22,16 +22,15 @@ namespace dataset {
const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##";
const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100;
const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]";
const bool WordpieceTokenizerOp::kDefWithOffsets = false;
WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
const int &max_bytes_per_token, const std::string &unknown_token,
const bool &with_offsets)
: vocab_(vocab),
: TokenizerOp(with_offsets),
vocab_(vocab),
suffix_indicator_(suffix_indicator),
max_bytes_per_token_(max_bytes_per_token),
unknown_token_(unknown_token),
with_offsets_(with_offsets) {}
unknown_token_(unknown_token) {}
Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start,
bool *out_found, int *out_end) const {

View File

@ -24,6 +24,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/text/kernels/tokenizer_op.h"
#include "minddata/dataset/text/vocab.h"
#include "minddata/dataset/util/status.h"
@ -32,12 +33,11 @@ using cppjieba::RuneStrArray;
namespace mindspore {
namespace dataset {
class WordpieceTokenizerOp : public TensorOp {
class WordpieceTokenizerOp : public TokenizerOp {
public:
static const char kDefSuffixIndicator[];
static const int kDefMaxBytesPerToken;
static const char kDefUnknownToken[];
static const bool kDefWithOffsets;
WordpieceTokenizerOp(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator = kDefSuffixIndicator,
const int &max_bytes_per_token = kDefMaxBytesPerToken,
const std::string &unknown_token = kDefUnknownToken, const bool &with_offsets = kDefWithOffsets);
@ -61,7 +61,6 @@ class WordpieceTokenizerOp : public TensorOp {
private:
const std::shared_ptr<Vocab> vocab_;
const std::string suffix_indicator_;
const bool with_offsets_;
const int max_bytes_per_token_;
const std::string unknown_token_;
};

View File

@ -2538,7 +2538,6 @@ class MapDataset(Dataset):
# If output_columns were not provided then use input_columns
self.output_columns = self.input_columns if not self.output_columns else self.output_columns
# todo(crc): move to @check_map
if self.input_columns and self.output_columns \
and len(self.input_columns) != len(self.output_columns) \
and not self.column_order:
@ -3237,8 +3236,8 @@ class MnistDataset(MappableDataset):
"""
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None,
num_shards=None, shard_id=None, cache=None):
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
@ -4129,8 +4128,8 @@ class Cifar10Dataset(MappableDataset):
"""
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None,
num_shards=None, shard_id=None, cache=None):
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
@ -4233,8 +4232,8 @@ class Cifar100Dataset(MappableDataset):
"""
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None,
num_shards=None, shard_id=None, cache=None):
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
@ -4798,8 +4797,8 @@ class CelebADataset(MappableDataset):
def parse(self, children=None):
if self.usage != "all":
dir = os.path.realpath(self.dataset_dir)
partition_file = os.path.join(dir, "list_eval_partition.txt")
dataset_dir = os.path.realpath(self.dataset_dir)
partition_file = os.path.join(dataset_dir, "list_eval_partition.txt")
if os.path.exists(partition_file) is False:
raise RuntimeError("Partition file can not be found when usage is not 'all'.")
return cde.CelebANode(self.dataset_dir, self.usage, self.sampler, self.decode, self.extensions)
@ -4867,82 +4866,7 @@ class CLUEDataset(SourceDataset):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_files = self._find_files(dataset_files)
self.task_dict = {
'AFQMC': {
'train': {
'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'
},
'test': {
'id': 'id', 'sentence1': 'sentence1', 'sentence2': 'sentence2'
},
'eval': {
'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'
}
},
'CMNLI': {
'train': {
'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'
},
'test': {
'id': 'id', 'sentence1': 'sentence1', 'sentence2': 'sentence2'
},
'eval': {
'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'
}
},
'CSL': {
'train': {
'id': 'id', 'abst': 'abst', 'keyword': 'keyword', 'label': 'label'
},
'test': {
'id': 'id', 'abst': 'abst', 'keyword': 'keyword'
},
'eval': {
'id': 'id', 'abst': 'abst', 'keyword': 'keyword', 'label': 'label'
}
},
'IFLYTEK': {
'train': {
'label': 'label', 'label_des': 'label_des', 'sentence': 'sentence'
},
'test': {
'id': 'id', 'sentence': 'sentence',
},
'eval': {
'label': 'label', 'label_des': 'label_des', 'sentence': 'sentence'
}
},
'TNEWS': {
'train': {
'label': 'label', 'label_desc': 'label_desc', 'sentence': 'sentence', 'keywords': 'keywords'
},
'test': {
'id': 'id', 'sentence': 'sentence', 'keywords': 'keywords'
},
'eval': {
'label': 'label', 'label_desc': 'label_desc', 'sentence': 'sentence', 'keywords': 'keywords'
}
},
'WSC': {
'train': {
'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index',
'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text', 'idx': 'idx',
'label': 'label', 'text': 'text'
},
'test': {
'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index',
'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text', 'idx': 'idx', 'text': 'text'
},
'eval': {
'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index',
'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text', 'idx': 'idx',
'label': 'label', 'text': 'text'
}
}
}
self.usage = replace_none(usage, 'train')
self.cols_to_keyword = self.task_dict[task][self.usage]
self.task = replace_none(task, 'AFQMC')
def parse(self, children=None):
@ -5047,7 +4971,8 @@ class TextFileDataset(SourceDataset):
self.dataset_files.sort()
def parse(self, children=None):
return cde.TextFileNode(self.dataset_files, self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
return cde.TextFileNode(self.dataset_files, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
class _NumpySlicesDataset:

View File

@ -45,7 +45,7 @@ def test_jieba_callable():
# test input multiple tensors
with pytest.raises(RuntimeError) as info:
_ = jieba_op1(text1, text2)
assert "JiebaTokenizer: input only support one column data." in str(info.value)
assert "JiebaTokenizerOp: input should be one column data." in str(info.value)
def test_jieba_1():