From 93810a0dc84e44ff8905428e32654aaa2d78e2ce Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Fri, 21 Aug 2020 15:57:44 -0400 Subject: [PATCH] C++ API: Minor fixes for dataset parameters --- .../ccsrc/minddata/dataset/api/datasets.cc | 4 ++-- .../ccsrc/minddata/dataset/include/datasets.h | 22 +++++++++---------- .../ccsrc/minddata/dataset/include/samplers.h | 10 ++++----- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 97ce2820dae..8ace68d2a3c 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -218,7 +218,7 @@ std::shared_ptr operator+(const std::shared_ptr &dataset } // Function to create a TextFileDataset. -std::shared_ptr TextFile(const std::vector &dataset_files, int32_t num_samples, +std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) { auto ds = std::make_shared(dataset_files, num_samples, shuffle, num_shards, shard_id); @@ -1331,7 +1331,7 @@ bool TextFileDataset::ValidateParams() { return false; } - if (!ValidateDatasetShardParams("TextfileDataset", num_shards_, shard_id_)) { + if (!ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_)) { return false; } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 445c3690b16..ea3f65a5ed9 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -84,10 +84,10 @@ std::shared_ptr Schema(const std::string &schema_file = ""); // The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. /// \param[in] dataset_dir Path to the root directory that contains the dataset. /// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'. -/// \param[in] decode Decode the images after reading (default=False). -/// \param[in] extensions List of file extensions to be included in the dataset (default=None). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset +/// \param[in] decode Decode the images after reading (default=false). +/// \param[in] extensions Set of file extensions to be included in the dataset (default={}). /// \return Shared pointer to the current Dataset std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all", const std::shared_ptr &sampler = nullptr, bool decode = false, @@ -199,11 +199,11 @@ std::shared_ptr ImageFolder(const std::string &dataset_dir, /// \notes The generated dataset has two columns ['image', 'label'] /// \param[in] dataset_file The dataset file to be read /// \param[in] usage Need "train", "eval" or "inference" data (default="train") -/// \param[in] decode Decode the images after reading (default=false). -/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder -/// names will be sorted alphabetically and each class will be given a unique index starting from 0). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, /// A `RandomSampler` will be used to randomly iterate the entire dataset +/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder +/// names will be sorted alphabetically and each class will be given a unique index starting from 0). +/// \param[in] decode Decode the images after reading (default=false). /// \return Shared pointer to the current ManifestDataset std::shared_ptr Manifest(std::string dataset_file, std::string usage = "train", std::shared_ptr sampler = nullptr, @@ -230,13 +230,13 @@ std::shared_ptr operator+(const std::shared_ptr &dataset /// \brief Function to create a RandomDataset /// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random) /// \param[in] schema SchemaObj to set column type, data type and data shape -/// \param[in] columns_list List of columns to be read (default=None, read all columns) +/// \param[in] columns_list List of columns to be read (default={}, read all columns) /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset template > std::shared_ptr RandomData(const int32_t &total_rows = 0, T schema = nullptr, - std::vector columns_list = {}, + const std::vector &columns_list = {}, std::shared_ptr sampler = nullptr) { auto ds = std::make_shared(total_rows, schema, std::move(columns_list), std::move(sampler)); return ds->ValidateParams() ? ds : nullptr; @@ -257,7 +257,7 @@ std::shared_ptr RandomData(const int32_t &total_rows = 0, T schem /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) /// \return Shared pointer to the current TextFileDataset -std::shared_ptr TextFile(const std::vector &dataset_files, int32_t num_samples = 0, +std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); @@ -302,7 +302,7 @@ class Dataset : public std::enable_shared_from_this { virtual std::vector> Build() = 0; /// \brief Pure virtual function for derived class to implement parameters validation - /// \return bool True if all the params are valid + /// \return bool true if all the parameters are valid virtual bool ValidateParams() = 0; /// \brief Setter function for runtime number of workers @@ -767,8 +767,8 @@ class RandomDataset : public Dataset { static constexpr int32_t kMaxDimValue = 32; /// \brief Constructor - RandomDataset(const int32_t &total_rows, std::shared_ptr schema, std::vector columns_list, - std::shared_ptr sampler) + RandomDataset(const int32_t &total_rows, std::shared_ptr schema, + const std::vector &columns_list, std::shared_ptr sampler) : total_rows_(total_rows), schema_path_(""), schema_(std::move(schema)), diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 9d423c78fa0..6504a01595c 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ #include #include @@ -70,7 +70,7 @@ std::shared_ptr PKSampler(int64_t num_val, bool shuffle = false, i /// Function to create a Random Sampler. /// \notes Samples the elements randomly. -/// \param[in] replacement - If True, put the sample ID back for the next draw. +/// \param[in] replacement - If true, put the sample ID back for the next draw. /// \param[in] num_samples - The number of samples to draw (default to all elements). /// \return Shared pointer to the current Sampler. std::shared_ptr RandomSampler(bool replacement = false, int64_t num_samples = 0); @@ -94,7 +94,7 @@ std::shared_ptr SubsetRandomSampler(std::vector /// weights (probabilities). /// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1. /// \param[in] num_samples - The number of samples to draw (default to all elements). -/// \param[in] replacement - If True, put the sample ID back for the next draw. +/// \param[in] replacement - If true, put the sample ID back for the next draw. /// \return Shared pointer to the current Sampler. std::shared_ptr WeightedRandomSampler(std::vector weights, int64_t num_samples = 0, bool replacement = true); @@ -199,4 +199,4 @@ class WeightedRandomSamplerObj : public SamplerObj { } // namespace api } // namespace dataset } // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_