forked from mindspore-Ecosystem/mindspore
C++ API: Minor fixes for dataset parameters
This commit is contained in:
parent
3eef4a4e06
commit
93810a0dc8
|
@ -218,7 +218,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
}
|
||||
|
||||
// Function to create a TextFileDataset.
|
||||
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples,
|
||||
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
|
||||
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
|
||||
|
||||
|
@ -1331,7 +1331,7 @@ bool TextFileDataset::ValidateParams() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!ValidateDatasetShardParams("TextfileDataset", num_shards_, shard_id_)) {
|
||||
if (!ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -84,10 +84,10 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "");
|
|||
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||
/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'.
|
||||
/// \param[in] decode Decode the images after reading (default=False).
|
||||
/// \param[in] extensions List of file extensions to be included in the dataset (default=None).
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \param[in] decode Decode the images after reading (default=false).
|
||||
/// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
|
||||
|
@ -199,11 +199,11 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir,
|
|||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \param[in] dataset_file The dataset file to be read
|
||||
/// \param[in] usage Need "train", "eval" or "inference" data (default="train")
|
||||
/// \param[in] decode Decode the images after reading (default=false).
|
||||
/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
|
||||
/// names will be sorted alphabetically and each class will be given a unique index starting from 0).
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
|
||||
/// A `RandomSampler` will be used to randomly iterate the entire dataset
|
||||
/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
|
||||
/// names will be sorted alphabetically and each class will be given a unique index starting from 0).
|
||||
/// \param[in] decode Decode the images after reading (default=false).
|
||||
/// \return Shared pointer to the current ManifestDataset
|
||||
std::shared_ptr<ManifestDataset> Manifest(std::string dataset_file, std::string usage = "train",
|
||||
std::shared_ptr<SamplerObj> sampler = nullptr,
|
||||
|
@ -230,13 +230,13 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
/// \brief Function to create a RandomDataset
|
||||
/// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random)
|
||||
/// \param[in] schema SchemaObj to set column type, data type and data shape
|
||||
/// \param[in] columns_list List of columns to be read (default=None, read all columns)
|
||||
/// \param[in] columns_list List of columns to be read (default={}, read all columns)
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
template <typename T = std::shared_ptr<SchemaObj>>
|
||||
std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schema = nullptr,
|
||||
std::vector<std::string> columns_list = {},
|
||||
const std::vector<std::string> &columns_list = {},
|
||||
std::shared_ptr<SamplerObj> sampler = nullptr) {
|
||||
auto ds = std::make_shared<RandomDataset>(total_rows, schema, std::move(columns_list), std::move(sampler));
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -257,7 +257,7 @@ std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schem
|
|||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
/// specified only when num_shards is also specified. (Default = 0)
|
||||
/// \return Shared pointer to the current TextFileDataset
|
||||
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
|
||||
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0,
|
||||
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
||||
int32_t shard_id = 0);
|
||||
|
||||
|
@ -302,7 +302,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
|
||||
|
||||
/// \brief Pure virtual function for derived class to implement parameters validation
|
||||
/// \return bool True if all the params are valid
|
||||
/// \return bool true if all the parameters are valid
|
||||
virtual bool ValidateParams() = 0;
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
|
@ -767,8 +767,8 @@ class RandomDataset : public Dataset {
|
|||
static constexpr int32_t kMaxDimValue = 32;
|
||||
|
||||
/// \brief Constructor
|
||||
RandomDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, std::vector<std::string> columns_list,
|
||||
std::shared_ptr<SamplerObj> sampler)
|
||||
RandomDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
|
||||
const std::vector<std::string> &columns_list, std::shared_ptr<SamplerObj> sampler)
|
||||
: total_rows_(total_rows),
|
||||
schema_path_(""),
|
||||
schema_(std::move(schema)),
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
@ -70,7 +70,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle = false, i
|
|||
|
||||
/// Function to create a Random Sampler.
|
||||
/// \notes Samples the elements randomly.
|
||||
/// \param[in] replacement - If True, put the sample ID back for the next draw.
|
||||
/// \param[in] replacement - If true, put the sample ID back for the next draw.
|
||||
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
||||
/// \return Shared pointer to the current Sampler.
|
||||
std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement = false, int64_t num_samples = 0);
|
||||
|
@ -94,7 +94,7 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t>
|
|||
/// weights (probabilities).
|
||||
/// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1.
|
||||
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
||||
/// \param[in] replacement - If True, put the sample ID back for the next draw.
|
||||
/// \param[in] replacement - If true, put the sample ID back for the next draw.
|
||||
/// \return Shared pointer to the current Sampler.
|
||||
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
|
||||
bool replacement = true);
|
||||
|
@ -199,4 +199,4 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
|
Loading…
Reference in New Issue