From 9064a5f7319f215e12652a0784ae957b6c9783d0 Mon Sep 17 00:00:00 2001 From: luoyang Date: Thu, 17 Sep 2020 14:35:47 +0800 Subject: [PATCH] add num_row check & fix api distributed sampler, lookup, default value of usage --- .../ccsrc/minddata/dataset/api/datasets.cc | 18 +++-- .../ccsrc/minddata/dataset/api/samplers.cc | 11 ++-- mindspore/ccsrc/minddata/dataset/api/text.cc | 11 +--- .../ccsrc/minddata/dataset/core/tensor.cc | 1 + .../minddata/dataset/core/tensor_shape.cc | 2 +- .../engine/datasetops/source/album_op.cc | 4 ++ .../engine/datasetops/source/celeba_op.cc | 3 +- .../engine/datasetops/source/cifar_op.cc | 2 +- .../engine/datasetops/source/clue_op.cc | 5 +- .../engine/datasetops/source/coco_op.cc | 4 ++ .../engine/datasetops/source/csv_op.cc | 3 +- .../datasetops/source/image_folder_op.cc | 4 +- .../engine/datasetops/source/manifest_op.cc | 3 +- .../engine/datasetops/source/mnist_op.cc | 3 +- .../engine/datasetops/source/text_file_op.cc | 3 +- .../engine/datasetops/source/tf_reader_op.cc | 3 +- .../engine/datasetops/source/voc_op.cc | 4 ++ .../ccsrc/minddata/dataset/include/datasets.h | 16 ++--- .../ccsrc/minddata/dataset/include/samplers.h | 6 +- .../ccsrc/minddata/dataset/text/vocab.cc | 45 +++---------- tests/ut/cpp/dataset/build_vocab_test.cc | 4 +- .../cpp/dataset/c_api_dataset_cifar_test.cc | 24 +++++-- .../dataset/c_api_dataset_iterator_test.cc | 8 +-- .../ut/cpp/dataset/c_api_dataset_ops_test.cc | 28 ++++---- tests/ut/cpp/dataset/c_api_dataset_vocab.cc | 65 +++++++++++++------ tests/ut/cpp/dataset/c_api_datasets_test.cc | 4 +- tests/ut/cpp/dataset/c_api_transforms_test.cc | 20 +++--- 27 files changed, 160 insertions(+), 144 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 8dfc2a142d..648d566935 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -16,6 +16,7 @@ #include #include +#include #include 
"minddata/dataset/include/datasets.h" #include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/transforms.h" @@ -729,7 +730,14 @@ bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_p } bool ValidateStringValue(const std::string &str, const std::unordered_set &valid_strings) { - return valid_strings.find(str) != valid_strings.end(); + if (valid_strings.find(str) == valid_strings.end()) { + std::string mode; + mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, + [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); + MS_LOG(ERROR) << str << " does not match any mode in [" + mode + " ]"; + return false; + } + return true; } // Helper function to validate dataset input/output column parameter @@ -841,8 +849,7 @@ Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string bool Cifar10Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) && - ValidateDatasetSampler("Cifar10Dataset", sampler_) && - ValidateStringValue(usage_, {"train", "test", "all", ""}); + ValidateDatasetSampler("Cifar10Dataset", sampler_) && ValidateStringValue(usage_, {"train", "test", "all"}); } // Function to build CifarOp for Cifar10 @@ -870,8 +877,7 @@ Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::stri bool Cifar100Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_) && - ValidateDatasetSampler("Cifar100Dataset", sampler_) && - ValidateStringValue(usage_, {"train", "test", "all", ""}); + ValidateDatasetSampler("Cifar100Dataset", sampler_) && ValidateStringValue(usage_, {"train", "test", "all"}); } // Function to build CifarOp for Cifar100 @@ -1359,7 +1365,7 @@ MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shar : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} bool MnistDataset::ValidateParams() { - return 
ValidateStringValue(usage_, {"train", "test", "all", ""}) && + return ValidateStringValue(usage_, {"train", "test", "all"}) && ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_); } diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index 75c2c6bcc1..56ad874a65 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -31,8 +31,10 @@ SamplerObj::SamplerObj() {} /// Function to create a Distributed Sampler. std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, - int64_t num_samples, uint32_t seed, bool even_dist) { - auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed, even_dist); + int64_t num_samples, uint32_t seed, int64_t offset, + bool even_dist) { + auto sampler = + std::make_shared(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); // Input validation if (!sampler->ValidateParams()) { return nullptr; @@ -95,12 +97,13 @@ std::shared_ptr WeightedRandomSampler(std::vector DistributedSamplerObj::Build() { - return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, + return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, offset_, even_dist_); } diff --git a/mindspore/ccsrc/minddata/dataset/api/text.cc b/mindspore/ccsrc/minddata/dataset/api/text.cc index 594a4410cf..31247d1b7b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/text.cc +++ b/mindspore/ccsrc/minddata/dataset/api/text.cc @@ -42,15 +42,10 @@ bool LookupOperation::ValidateParams() { MS_LOG(ERROR) << "Lookup: vocab object type is incorrect or null."; return false; } - if (unknown_token_.empty()) { - MS_LOG(ERROR) << "Lookup: no unknown token is specified."; + default_id_ = vocab_->Lookup(unknown_token_); + if (default_id_ == Vocab::kNoTokenExists) { + MS_LOG(ERROR) << "Lookup: " << unknown_token_ << " 
doesn't exist in vocab."; return false; - } else { - default_id_ = vocab_->Lookup(unknown_token_); - if (default_id_ == Vocab::kNoTokenExists) { - MS_LOG(ERROR) << "Lookup: unknown_token: [" + unknown_token_ + "], does not exist in vocab."; - return false; - } } return true; } diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.cc b/mindspore/ccsrc/minddata/dataset/core/tensor.cc index 2c7bbb5b51..ee3694dad0 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.cc +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.cc @@ -263,6 +263,7 @@ Status Tensor::CreateFromFile(const std::string &path, std::shared_ptr * fs.open(path, std::ios::binary | std::ios::in); CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + path); int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); + CHECK_FAIL_RETURN_UNEXPECTED(num_bytes <= kDeMaxDim, "Invalid file to allocate tensor memory, check path: " + path); CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape{num_bytes}, DataType(DataType::DE_UINT8), out)); int64_t written_bytes = fs.read(reinterpret_cast((*out)->GetMutableBuffer()), num_bytes).gcount(); diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc index 19c3a6b457..51fa874ea7 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc @@ -158,7 +158,7 @@ void TensorShape::AddListToShape(const T &list) { } if (dim > kDeMaxDim) { std::stringstream ss; - ss << "Invalid shape data, dim (" << size << ") is larger than the maximum dim size(" << kDeMaxDim << ")!"; + ss << "Invalid shape data, dim (" << dim << ") is larger than the maximum dim size(" << kDeMaxDim << ")!"; MS_LOG(ERROR) << ss.str().c_str(); known_ = false; raw_shape_.clear(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc 
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc index 559cb1457f..33c0dbef9e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc @@ -119,6 +119,10 @@ Status AlbumOp::PrescanEntry() { std::sort(image_rows_.begin(), image_rows_.end(), StrComp); num_rows_ = image_rows_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "Invalid data, no valid data matching the dataset API AlbumDataset. Please check file path or dataset API."); + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index 8ff2cf5bf6..f703473c36 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -237,8 +237,7 @@ Status CelebAOp::ParseImageAttrInfo() { num_rows_ = image_labels_vec_.size(); if (num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API CelebADataset. " - "Please check file path or dataset API validation first"); + "Invalid data, no valid data matching the dataset API CelebADataset. Please check file path or dataset API."); } MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << "."; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index 6a32f2c17d..9d4568041a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -412,7 +412,7 @@ Status CifarOp::ParseCifarData() { if (num_rows_ == 0) { std::string api = cifar_type_ == kCifar10 ? 
"Cifar10Dataset" : "Cifar100Dataset"; RETURN_STATUS_UNEXPECTED("Invalid data, no valid data matching the dataset API " + api + - ". Please check file path or dataset API validation first."); + ". Please check file path or dataset API."); } cifar_raw_data_block_->Reset(); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 72983c73b4..73ed6d6500 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -192,7 +192,7 @@ Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, con js = nlohmann::json::parse(line); } catch (const std::exception &err) { // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse json file: " + line); + RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse json file: " + file); } int cols_count = cols_to_keyword_.size(); TensorRow tRow(cols_count, nullptr); @@ -482,8 +482,7 @@ Status ClueOp::CalculateNumRowsPerShard() { } if (all_num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API " - "validation first."); + "Invalid data, no valid data matching the dataset API CLUEDataset. 
Please check file path or dataset API."); } num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index f3c77bce45..99b45e53f9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -468,6 +468,10 @@ Status CocoOp::ParseAnnotationIds() { if (coordinate_map_.find(img) != coordinate_map_.end()) image_ids_.push_back(img); } num_rows_ = image_ids_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "Invalid data, no valid data matching the dataset API CocoDataset. Please check file path or dataset API."); + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 95983ca069..4192653e89 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -780,8 +780,7 @@ Status CsvOp::CalculateNumRowsPerShard() { } if (all_num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format " - "validation first."); + "Invalid data, no valid data matching the dataset API CsvDataset. 
Please check file path or CSV format."); } num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 43121d2a29..8ffabb7260 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -117,8 +117,8 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { num_rows_ = image_label_pairs_.size(); if (num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API ImageFolderDataset. Please check file path or dataset " - "API validation first."); + "Invalid data, no valid data matching the dataset API ImageFolderDataset. " + "Please check file path or dataset API."); } // free memory of two queues used for pre-scan folder_name_queue_->Reset(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 6b675d0112..0c32ed96b6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -386,8 +386,7 @@ Status ManifestOp::CountDatasetInfo() { num_rows_ = static_cast(image_labelname_.size()); if (num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API ManifestDataset.Please check file path or dataset API " - "validation first."); + "Invalid data, no valid data matching the dataset API ManifestDataset. 
Please check file path or dataset API."); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index af5d434a8e..c3f8c62e8a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -369,8 +369,7 @@ Status MnistOp::ParseMnistData() { num_rows_ = image_label_pairs_.size(); if (num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API MnistDataset.Please check file path or dataset API " - "validation first."); + "Invalid data, no valid data matching the dataset API MnistDataset. Please check file path or dataset API."); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 7bc189eaf3..d9161182b9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -473,8 +473,7 @@ Status TextFileOp::CalculateNumRowsPerShard() { } if (all_num_rows_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API TextFileDataset.Please check file path or dataset API " - "validation first."); + "Invalid data, no valid data matching the dataset API TextFileDataset. 
Please check file path or dataset API."); } num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index b112eb4138..79d39a0fd0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -229,8 +229,7 @@ Status TFReaderOp::CalculateNumRowsPerShard() { num_rows_per_shard_ = static_cast(std::ceil(num_rows_ * 1.0 / num_devices_)); if (num_rows_per_shard_ == 0) { RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API TFRecordDataset.Please check file path or dataset API " - "validation first."); + "Invalid data, no valid data matching the dataset API TFRecordDataset. Please check file path or dataset API."); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index 179b92d5bf..eeb53ab825 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -315,6 +315,10 @@ Status VOCOp::ParseAnnotationIds() { } num_rows_ = image_ids_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "Invalid data, no valid data matching the dataset API VOCDataset. 
Please check file path or dataset API."); + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index b65ab6a9eb..e9b6160240 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -113,7 +113,7 @@ std::shared_ptr Album(const std::string &dataset_dir, const std::s /// \notes The generated dataset has two columns ['image', 'attr']. /// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. /// \param[in] dataset_dir Path to the root directory that contains the dataset. -/// \param[in] usage One of "all", "train", "valid" or "test". +/// \param[in] usage One of "all", "train", "valid" or "test" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \param[in] decode Decode the images after reading (default=false). @@ -126,21 +126,21 @@ std::shared_ptr CelebA(const std::string &dataset_dir, const std: /// \brief Function to create a Cifar10 Dataset /// \notes The generated dataset has two columns ["image", "label"] /// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] usage of CIFAR10, can be "train", "test" or "all" +/// \param[in] usage of CIFAR10, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage = std::string(), +std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage = "all", const std::shared_ptr &sampler = RandomSampler()); /// \brief Function to create a Cifar100 Dataset /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"] /// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] usage of CIFAR100, can be "train", "test" or "all" +/// \param[in] usage of CIFAR100, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage = std::string(), +std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage = "all", const std::shared_ptr &sampler = RandomSampler()); /// \brief Function to create a CLUEDataset @@ -247,11 +247,11 @@ std::shared_ptr Manifest(const std::string &dataset_file, const /// \brief Function to create a MnistDataset /// \notes The generated dataset has two columns ["image", "label"] /// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] usage of MNIST, can be "train", "test" or "all" +/// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current MnistDataset -std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage = std::string(), +std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage = "all", const std::shared_ptr &sampler = RandomSampler()); /// \brief Function to create a ConcatDataset @@ -407,7 +407,7 @@ std::shared_ptr TFRecord(const std::vector &datase /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]]. /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection" -/// \param[in] usage The type of data list text file to be read +/// \param[in] usage The type of data list text file to be read (default = "train"). /// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task /// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 6504a01595..204db81119 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -52,12 +52,13 @@ class WeightedRandomSamplerObj; /// \param[in] shuffle - If true, the indices are shuffled. /// \param[in] num_samples - The number of samples to draw (default to all elements). /// \param[in] seed - The seed in use when shuffle is true. +/// \param[in] offset - The starting position where access to elements in the dataset begins. /// \param[in] even_dist - If true, each shard would return the same number of rows (default to true). 
/// If false the total rows returned by all the shards would not have overlap. /// \return Shared pointer to the current Sampler. std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0, uint32_t seed = 1, - bool even_dist = true); + int64_t offset = -1, bool even_dist = true); /// Function to create a PK Sampler. /// \notes Samples K elements for each P class in the dataset. @@ -103,7 +104,7 @@ std::shared_ptr WeightedRandomSampler(std::vector &words, const std::vector &special_tokens, bool prepend_special, std::shared_ptr *vocab) { - // Validate parameters - std::string duplicate_word; - for (const WordType &word : words) { - if (std::count(words.begin(), words.end(), word) > 1) { - if (duplicate_word.find(word) == std::string::npos) { - duplicate_word = duplicate_word.empty() ? duplicate_word + word : duplicate_word + ", " + word; - } - } - } - if (!duplicate_word.empty()) { - MS_LOG(ERROR) << "words contains duplicate word: " << duplicate_word; - RETURN_STATUS_UNEXPECTED("words contains duplicate word: " + duplicate_word); - } - - std::string duplicate_sp; - std::string existed_sp; - for (const WordType &sp : special_tokens) { - if (std::count(special_tokens.begin(), special_tokens.end(), sp) > 1) { - if (duplicate_sp.find(sp) == std::string::npos) { - duplicate_sp = duplicate_sp.empty() ? duplicate_sp + sp : duplicate_sp + ", " + sp; - } - } - if (std::count(words.begin(), words.end(), sp) >= 1) { - if (existed_sp.find(sp) == std::string::npos) { - existed_sp = existed_sp.empty() ? 
existed_sp + sp : existed_sp + ", " + sp; - } - } - } - if (!duplicate_sp.empty()) { - MS_LOG(ERROR) << "special_tokens contains duplicate word: " << duplicate_sp; - RETURN_STATUS_UNEXPECTED("special_tokens contains duplicate word: " + duplicate_sp); - } - if (!existed_sp.empty()) { - MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " << existed_sp; - RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + existed_sp); - } - std::unordered_map word2id; // if special is added in front, normal words id will start from number of special tokens WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; for (auto word : words) { + if (word2id.find(word) != word2id.end()) { + MS_LOG(ERROR) << "word_list contains duplicate word: " + word + "."; + RETURN_STATUS_UNEXPECTED("word_list contains duplicate word: " + word + "."); + } word2id[word] = word_id++; } word_id = prepend_special ? 0 : word2id.size(); for (auto special_token : special_tokens) { + if (word2id.find(special_token) != word2id.end()) { + MS_LOG(ERROR) << "special_tokens and word_list contain duplicate word: " + special_token + "."; + RETURN_STATUS_UNEXPECTED("special_tokens and word_list contain duplicate word: " + special_token + "."); + } word2id[special_token] = word_id++; } diff --git a/tests/ut/cpp/dataset/build_vocab_test.cc b/tests/ut/cpp/dataset/build_vocab_test.cc index 86f7a9a377..b742a3c616 100644 --- a/tests/ut/cpp/dataset/build_vocab_test.cc +++ b/tests/ut/cpp/dataset/build_vocab_test.cc @@ -183,8 +183,8 @@ TEST_F(MindDataTestVocab, TestVocabFromVectorFail2) { TEST_F(MindDataTestVocab, TestVocabFromVectorFail3) { MS_LOG(INFO) << "Doing MindDataTestVocab-TestVocabFromVectorFail3."; // Build vocab from a vector - std::vector list = {"apple", "dog", "egg", "", ""}; - std::vector sp_tokens = {"", ""}; + std::vector list = {"apple", "dog", "egg", "", ""}; + std::vector sp_tokens = {"", ""}; std::shared_ptr vocab = std::make_shared(); 
// Expected failure: special tokens are already existed in word_list diff --git a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc index c21e465ef9..b0905ca40f 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc @@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -62,7 +62,7 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) { // Create a Cifar100 Dataset std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; - std::shared_ptr ds = Cifar100(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar100(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -96,7 +96,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1."; // Create a Cifar100 Dataset - std::shared_ptr ds = Cifar100("", std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar100("", "all", RandomSampler(false, 10)); EXPECT_EQ(ds, nullptr); } @@ -104,7 +104,17 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1."; // Create a Cifar10 Dataset - std::shared_ptr ds = Cifar10("", std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10("", "all", RandomSampler(false, 10)); + EXPECT_EQ(ds, nullptr); +} + +TEST_F(MindDataTestPipeline, TestCifar10DatasetWithInvalidUsage) { + MS_LOG(INFO) << "Doing 
MindDataTestPipeline-TestCifar10DatasetWithInvalidUsage."; + + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, "validation"); + // Expect failure: validation is not a valid usage EXPECT_EQ(ds, nullptr); } @@ -113,7 +123,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetWithNullSampler) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), nullptr); + std::shared_ptr ds = Cifar10(folder_path, "all", nullptr); // Expect failure: sampler can not be nullptr EXPECT_EQ(ds, nullptr); } @@ -123,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithNullSampler) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; - std::shared_ptr ds = Cifar100(folder_path, std::string(), nullptr); + std::shared_ptr ds = Cifar100(folder_path, "all", nullptr); // Expect failure: sampler can not be nullptr EXPECT_EQ(ds, nullptr); } @@ -133,7 +143,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithWrongSampler) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; - std::shared_ptr ds = Cifar100(folder_path, std::string(), RandomSampler(false, -10)); + std::shared_ptr ds = Cifar100(folder_path, "all", RandomSampler(false, -10)); // Expect failure: sampler is not construnced correctly EXPECT_EQ(ds, nullptr); } diff --git a/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc b/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc index 84573cbb30..f5623a51f6 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc @@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestIteratorEmptyColumn) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorEmptyColumn."; // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ 
+ "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 5)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 5)); EXPECT_NE(ds, nullptr); // Create a Rename operation on ds @@ -64,7 +64,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn."; // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 4)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 4)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -103,7 +103,7 @@ TEST_F(MindDataTestPipeline, TestIteratorReOrder) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorReOrder."; // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), SequentialSampler(false, 4)); + std::shared_ptr ds = Cifar10(folder_path, "all", SequentialSampler(false, 4)); EXPECT_NE(ds, nullptr); // Create a Take operation on ds @@ -186,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn."; // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 4)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 4)); EXPECT_NE(ds, nullptr); // Pass wrong column name diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index 0a7208c413..f8b2b2a856 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -40,7 +40,7 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + 
"/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds @@ -82,7 +82,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess1) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -118,7 +118,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -156,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail1) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -171,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail2) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -186,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail3) { // Create a Mnist Dataset std::string folder_path = 
datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -201,7 +201,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail4) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -216,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail5) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -231,7 +231,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail6) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, -2, 3}); @@ -245,7 +245,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail7) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a BucketBatchByLength operation on ds @@ -312,7 +312,7 @@ 
TEST_F(MindDataTestPipeline, TestConcatSuccess) { // Create a Cifar10 Dataset // Column names: {"image", "label"} folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9)); + std::shared_ptr ds2 = Cifar10(folder_path, "all", RandomSampler(false, 9)); EXPECT_NE(ds2, nullptr); // Create a Project operation on ds @@ -364,7 +364,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess2) { // Create a Cifar10 Dataset // Column names: {"image", "label"} folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9)); + std::shared_ptr ds2 = Cifar10(folder_path, "all", RandomSampler(false, 9)); EXPECT_NE(ds2, nullptr); // Create a Project operation on ds @@ -1012,7 +1012,7 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), RandomSampler(false, 20)); + std::shared_ptr ds = Mnist(folder_path, "all", RandomSampler(false, 20)); EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds @@ -1126,7 +1126,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) { EXPECT_NE(ds1, nullptr); folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds2 = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds2, nullptr); // Create a Project operation on ds diff --git a/tests/ut/cpp/dataset/c_api_dataset_vocab.cc b/tests/ut/cpp/dataset/c_api_dataset_vocab.cc index 11926e5bb8..872a56d309 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_vocab.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_vocab.cc @@ -80,6 +80,50 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) { } } +TEST_F(MindDataTestPipeline, TestVocabLookupOpEmptyString) { + MS_LOG(INFO) << "Doing 
MindDataTestPipeline-TestVocabLookupOpEmptyString."; + + // Create a TextFile dataset + std::string data_file = datasets_root_path_ + "/testVocab/words.txt"; + std::shared_ptr ds = TextFile({data_file}, 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create a vocab from vector + std::vector list = {"home", "IS", "behind", "the", "world", "ahead", "!"}; + std::shared_ptr vocab = std::make_shared(); + Status s = Vocab::BuildFromVector(list, {"", ""}, true, &vocab); + EXPECT_EQ(s, Status::OK()); + + // Create Lookup operation on ds + std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); + EXPECT_NE(lookup, nullptr); + + // Create Map operation on ds + ds = ds->Map({lookup}, {"text"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + std::vector expected = {2, 1, 4, 5, 6, 7}; + while (row.size() != 0) { + auto ind = row["text"]; + MS_LOG(INFO) << ind->shape() << " " << *ind; + std::shared_ptr expected_item; + Tensor::CreateScalar(expected[i], &expected_item); + EXPECT_EQ(*ind, *expected_item); + iter->GetNextRow(&row); + i++; + } +} + TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabLookupOpFail1."; // Create a TextFile Dataset @@ -110,27 +154,6 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) { EXPECT_EQ(lookup, nullptr); } -TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabLookupOpWithEmptyUnknownToken."; - - // Create a TextFile dataset - std::string data_file = datasets_root_path_ + "/testVocab/words.txt"; - std::shared_ptr ds = TextFile({data_file}, 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); 
- - // Create a vocab from map - std::unordered_map dict; - dict["Home"] = 3; - std::shared_ptr vocab = std::make_shared(); - Status s = Vocab::BuildFromUnorderedMap(dict, &vocab); - EXPECT_EQ(s, Status::OK()); - - // Create Lookup operation on ds - // Expected failure: "" is not a word of vocab - std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); - EXPECT_EQ(lookup, nullptr); -} - TEST_F(MindDataTestPipeline, TestVocabFromDataset) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDataset."; diff --git a/tests/ut/cpp/dataset/c_api_datasets_test.cc b/tests/ut/cpp/dataset/c_api_datasets_test.cc index b8040f6e40..ad950f8334 100644 --- a/tests/ut/cpp/dataset/c_api_datasets_test.cc +++ b/tests/ut/cpp/dataset/c_api_datasets_test.cc @@ -133,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir."; // Create a Mnist Dataset - std::shared_ptr ds = Mnist("", std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Mnist("", "all", RandomSampler(false, 10)); EXPECT_EQ(ds, nullptr); } @@ -142,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithNullSampler) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, std::string(), nullptr); + std::shared_ptr ds = Mnist(folder_path, "all", nullptr); // Expect failure: sampler can not be nullptr EXPECT_EQ(ds, nullptr); } diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 16bcb39887..0df131076f 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -30,7 +30,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; int number_of_classes = 10; - std::shared_ptr ds = Cifar10(folder_path, std::string(), 
RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create objects for the tensor ops @@ -98,7 +98,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; int number_of_classes = 10; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -156,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) { // Must fail because alpha can't be negative // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -181,7 +181,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) { // Must fail because prob can't be negative // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -206,7 +206,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) { // Must fail because alpha can't be zero // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -376,7 +376,7 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) { TEST_F(MindDataTestPipeline, 
TestMixUpBatchFail1) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -400,7 +400,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) { // This should fail because alpha can't be zero // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -423,7 +423,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) { TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -472,7 +472,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) { // Create a Cifar10 Dataset std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - std::shared_ptr ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); + std::shared_ptr ds = Cifar10(folder_path, "all", RandomSampler(false, 10)); EXPECT_NE(ds, nullptr); // Create a Batch operation on ds @@ -1118,7 +1118,7 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) { TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { // Create a Mnist Dataset std::string folder_path = datasets_root_path_ + "/testMnistData/"; - std::shared_ptr ds = Mnist(folder_path, "", RandomSampler(false, 20)); + std::shared_ptr ds = 
Mnist(folder_path, "all", RandomSampler(false, 20)); EXPECT_NE(ds, nullptr); // Create a Repeat operation on ds