forked from mindspore-Ecosystem/mindspore
!5861 [MD] Change return val of GetNextRow in c-api
Merge pull request !5861 from luoyang/c-api-pyfunc
This commit is contained in: commit 2ff6dd3b77
@@ -327,9 +327,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
   // Finish building vocab by triggering GetNextRow
   std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
-  iter->GetNextRow(&row);
-  if (vocab->vocab().empty()) {
-    MS_LOG(ERROR) << "Fail to build vocab.";
+  if (!iter->GetNextRow(&row)) {
     return nullptr;
   }
@@ -1782,7 +1780,7 @@ bool BuildVocabDataset::ValidateParams() {
     MS_LOG(ERROR) << "BuildVocab: vocab is null.";
     return false;
   }
-  if (top_k_ < 0) {
+  if (top_k_ <= 0) {
     MS_LOG(ERROR) << "BuildVocab: top_k shoule be positive, but got: " << top_k_;
     return false;
   }
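The tightened check means a top_k of zero is now rejected during parameter validation, so the failure surfaces at the BuildVocab call itself. A minimal caller-side sketch, not part of this commit, mirroring the BuildVocab call shape used in the tests further below; `ds` is assumed to be any previously constructed text dataset:

// Illustrative only: the stricter top_k validation.
// ds is assumed to be a valid std::shared_ptr<Dataset> built elsewhere.
std::shared_ptr<Vocab> vocab =
    ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, /*top_k=*/0, {});
// With top_k <= 0, ValidateParams() now fails, so BuildVocab is expected to return nullptr.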
@@ -22,25 +22,29 @@ namespace dataset {
 namespace api {
 
 // Get the next row from the data pipeline.
-void Iterator::GetNextRow(TensorMap *row) {
+bool Iterator::GetNextRow(TensorMap *row) {
   Status rc = iterator_->GetNextAsMap(row);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
     row->clear();
+    return false;
   }
+  return true;
 }
 
 // Get the next row from the data pipeline.
-void Iterator::GetNextRow(TensorVec *row) {
+bool Iterator::GetNextRow(TensorVec *row) {
   TensorRow tensor_row;
   Status rc = iterator_->FetchNextTensorRow(&tensor_row);
   if (rc.IsError()) {
     MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
     row->clear();
+    return false;
   }
   // Generate a vector as return
   row->clear();
   std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row));
+  return true;
 }
 
 // Shut down the data pipeline.
@@ -56,12 +56,14 @@ class Iterator {
   /// \brief Function to get the next row from the data pipeline.
   /// \note Type of return data is a map(with column name).
   /// \param[out] row - the output tensor row.
-  void GetNextRow(TensorMap *row);
+  /// \return Returns true if no error encountered else false.
+  bool GetNextRow(TensorMap *row);
 
   /// \brief Function to get the next row from the data pipeline.
   /// \note Type of return data is a vector(without column name).
   /// \param[out] row - the output tensor row.
-  void GetNextRow(TensorVec *row);
+  /// \return Returns true if no error encountered else false.
+  bool GetNextRow(TensorVec *row);
 
   /// \brief Function to shut down the data pipeline.
   void Stop();
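Taken together, the two signatures above change the caller-side idiom from calling GetNextRow blindly to checking its return value. A minimal consumption sketch under that contract, not part of this commit: includes and dataset construction are omitted, `ds` is assumed to be a valid std::shared_ptr<Dataset>, CreateIterator() is assumed as the usual entry point, and an exhausted pipeline is assumed to yield an empty row rather than an error (the pattern the tests below rely on):

// Sketch only: consuming the new bool return value of Iterator::GetNextRow.
void ConsumeAllRows(std::shared_ptr<Dataset> ds) {
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  if (!iter->GetNextRow(&row)) {
    return;  // fetch failed; GetNextRow already logged the error and cleared the row
  }
  while (row.size() != 0) {
    // ... use the tensors in `row` here ...
    if (!iter->GetNextRow(&row)) {
      break;  // error mid-iteration; row has been cleared
    }
  }
  iter->Stop();
}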
@@ -92,7 +92,7 @@ Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vec
   for (const WordType &word : words) {
     if (std::count(words.begin(), words.end(), word) > 1) {
       if (duplicate_word.find(word) == std::string::npos) {
-        duplicate_word = duplicate_word + ", " + word;
+        duplicate_word = duplicate_word.empty() ? duplicate_word + word : duplicate_word + ", " + word;
       }
     }
   }
@@ -102,10 +102,16 @@ Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vec
   }
 
   std::string duplicate_sp;
+  std::string existed_sp;
   for (const WordType &sp : special_tokens) {
     if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
       if (duplicate_sp.find(sp) == std::string::npos) {
-        duplicate_sp = duplicate_sp + ", " + sp;
+        duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
       }
     }
+    if (std::count(words.begin(), words.end(), sp) >= 1) {
+      if (existed_sp.find(sp) == std::string::npos) {
+        existed_sp = existed_sp.empty() ? existed_sp + sp : existed_sp + ", " + sp;
+      }
+    }
   }
@@ -113,6 +119,10 @@ Status Vocab::BuildFromVector(const std::vector<WordType> &words, const std::vec
     MS_LOG(ERROR) << "special_tokens contains duplicate word: " << duplicate_sp;
     RETURN_STATUS_UNEXPECTED("special_tokens contains duplicate word: " + duplicate_sp);
   }
+  if (!existed_sp.empty()) {
+    MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << existed_sp;
+    RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + existed_sp);
+  }
 
   std::unordered_map<WordType, WordIdType> word2id;
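The net effect of the BuildFromVector changes above: duplicate special tokens and overlaps between special_tokens and word_list are now both reported, each as a comma-separated list of the offending words. A small sketch of the second case, not part of this commit and mirroring the TestVocabFromVectorFail3 case added below:

// "<unk>" and "<pad>" appear both in the word list and in the special tokens,
// which the new existed_sp check rejects.
std::vector<std::string> words = {"apple", "dog", "egg", "<unk>", "<pad>"};
std::vector<std::string> specials = {"<pad>", "<unk>"};
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
Status s = Vocab::BuildFromVector(words, specials, true, &vocab);
// s is expected to be non-OK; the error message lists the overlapping words.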
@@ -151,7 +161,7 @@ Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delim
   for (const WordType &sp : special_tokens) {
     if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) {
       if (duplicate_sp.find(sp) == std::string::npos) {
-        duplicate_sp = duplicate_sp + ", " + sp;
+        duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp;
       }
     }
   }
@@ -179,12 +189,12 @@ Status Vocab::BuildFromFileCpp(const std::string &path, const std::string &delim
       word = word.substr(0, word.find_first_of(delimiter));
     }
     if (word2id.find(word) != word2id.end()) {
-      MS_LOG(ERROR) << "duplicate word:" + word + ".";
-      RETURN_STATUS_UNEXPECTED("duplicate word:" + word + ".");
+      MS_LOG(ERROR) << "word_list contains duplicate word:" + word + ".";
+      RETURN_STATUS_UNEXPECTED("word_list contains duplicate word:" + word + ".");
     }
     if (specials.find(word) != specials.end()) {
-      MS_LOG(ERROR) << word + " is already in special_tokens.";
-      RETURN_STATUS_UNEXPECTED(word + " is already in special_tokens.");
+      MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << word;
+      RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + word);
     }
     word2id[word] = word_id++;
     // break if enough row is read, if vocab_size is smaller than 0
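The same overlap rule now applies when building a vocab from a file, with the clearer "special_tokens and word_list contain duplicate word" message. A sketch of the failing call shape, not part of this commit and mirroring TestVocabFromFileFail4 added below; the file path is hypothetical and "home" is assumed to be a word inside that file:

// Illustrative only: BuildFromFileCpp now reports the overlap explicitly.
std::string vocab_file = "/path/to/testVocab/vocab_list.txt";  // hypothetical path
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
// "home" is assumed to appear in vocab_file, so passing it as a special token
// triggers the new "special_tokens and word_list contain duplicate word" error.
Status s = Vocab::BuildFromFileCpp(vocab_file, ",", -1, {"home"}, true, &vocab);
// s is expected to be non-OK.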
@@ -158,7 +158,7 @@ TEST_F(MindDataTestVocab, TestVocabFromEmptyVector) {
 
 TEST_F(MindDataTestVocab, TestVocabFromVectorFail1) {
   MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail1.";
-  // Build vocab from a vector of words with no special tokens
+  // Build vocab from a vector of words
   std::vector<std::string> list = {"apple", "apple", "cat", "cat", "egg"};
   std::vector<std::string> sp_tokens = {};
   std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
@@ -170,7 +170,7 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail1) {
 
 TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) {
   MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail2.";
-  // Build vocab from a vector of words with no special tokens
+  // Build vocab from a vector
   std::vector<std::string> list = {"apple", "dog", "egg"};
   std::vector<std::string> sp_tokens = {"<pad>", "<unk>", "<pad>", "<unk>", "<none>"};
   std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
@@ -180,6 +180,18 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) {
   EXPECT_NE(s, Status::OK());
 }
 
+TEST_F(MindDataTestVocab, TestVocabFromVectorFail3) {
+  MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail3.";
+  // Build vocab from a vector
+  std::vector<std::string> list = {"apple", "dog", "egg", "<unk>", "<pad>"};
+  std::vector<std::string> sp_tokens = {"<pad>", "<unk>"};
+  std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
+
+  // Expected failure: special tokens are already existed in word_list
+  Status s = Vocab::BuildFromVector(list, sp_tokens, true, &vocab);
+  EXPECT_NE(s, Status::OK());
+}
+
 TEST_F(MindDataTestVocab, TestVocabFromFile) {
   MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFile.";
   // Build vocab from local file
@@ -218,8 +230,8 @@ TEST_F(MindDataTestVocab, TestVocabFromFileFail2) {
 }
 
 TEST_F(MindDataTestVocab, TestVocabFromFileFail3) {
-  MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail2.";
-  // Build vocab from local file which is not exist
+  MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail3.";
+  // Build vocab from local file
   std::string vocab_dir = datasets_root_path_ + "/testVocab/vocab_list.txt";
   std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
@@ -227,3 +239,14 @@ TEST_F(MindDataTestVocab, TestVocabFromFileFail3) {
   Status s = Vocab::BuildFromFileCpp(vocab_dir, ",", -1, {"<unk>", "<unk>"}, true, &vocab);
   EXPECT_NE(s, Status::OK());
 }
+
+TEST_F(MindDataTestVocab, TestVocabFromFileFail4) {
+  MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromFileFail4.";
+  // Build vocab from local file
+  std::string vocab_dir = datasets_root_path_ + "/testVocab/vocab_list.txt";
+  std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();
+
+  // Expected failure: special_tokens and word_list contain duplicate word
+  Status s = Vocab::BuildFromFileCpp(vocab_dir, ",", -1, {"home"}, true, &vocab);
+  EXPECT_NE(s, Status::OK());
+}
@@ -271,6 +271,21 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) {
   EXPECT_EQ(vocab, nullptr);
 }
 
+TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail4) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetFail4.";
+
+  // Create a TextFile dataset
+  std::string data_file = datasets_root_path_ + "/testVocab/words.txt";
+  std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse);
+  EXPECT_NE(ds, nullptr);
+
+  // Create vocab from dataset
+  // Expected failure: special tokens are already in the dataset
+  std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()},
+                                                std::numeric_limits<int64_t>::max(), {"world"});
+  EXPECT_EQ(vocab, nullptr);
+}
+
 TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64.";
@@ -318,4 +333,4 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) {
     iter->GetNextRow(&row);
     i++;
   }
-}
\ No newline at end of file
+}