diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 1b8aece8421..019415bbc71 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -1200,7 +1200,7 @@ bool CSVDataset::ValidateParams() { return false; } - if (num_samples_ < -1) { + if (num_samples_ < 0) { MS_LOG(ERROR) << "CSVDataset: Invalid number of samples: " << num_samples_; return false; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index ff0285442de..6cff81c1b81 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { CsvOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(-1), builder_shuffle_files_(false) { + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -539,7 +539,7 @@ Status CsvOp::operator()() { RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); if (buffer->eoe()) { workers_done++; - } else if (num_samples_ == -1 || rows_read < num_samples_) { + } else if (num_samples_ == 0 || rows_read < num_samples_) { if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index de507ac8ba4..b8ebdc0c9d8 100644 --- 
a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -191,7 +191,7 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::str /// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, infers the /// column_names from the first row of CSV file. /// \param[in] num_samples The number of samples to be included in the dataset. -/// (Default = -1 means all samples.) +/// (Default = 0 means all samples.) /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal) /// Can be any of: /// ShuffleMode::kFalse - No shuffling is performed. @@ -203,7 +203,7 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::str /// \return Shared pointer to the current Dataset std::shared_ptr CSV(const std::vector &dataset_files, char field_delim = ',', const std::vector> &column_defaults = {}, - const std::vector &column_names = {}, int64_t num_samples = -1, + const std::vector &column_names = {}, int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7eb77480eae..7c8e6bf9215 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -5140,7 +5140,7 @@ class CSVDataset(SourceDataset): columns as string type. column_names (list[str], optional): List of column names of the dataset (default=None). If this is not provided, infers the column_names from the first row of CSV file. - num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset). + num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). 
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch @@ -5164,7 +5164,7 @@ class CSVDataset(SourceDataset): """ @check_csvdataset - def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1, + def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) @@ -5215,7 +5215,7 @@ class CSVDataset(SourceDataset): if self.dataset_size is None: num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) self.dataset_size = get_num_rows(num_rows, self.num_shards) - if self.num_samples != -1 and self.num_samples < self.dataset_size: - self.dataset_size = num_rows + if self.num_samples is not None and self.num_samples < self.dataset_size: + self.dataset_size = self.num_samples return self.dataset_size diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 4f53e6aabe9..4aae8f2c247 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -830,16 +830,12 @@ def check_csvdataset(method): def new_method(self, *args, **kwargs): _, param_dict = parse_user_args(method, *args, **kwargs) - nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id'] + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') type_check(dataset_files, (str, list), "dataset files") - # check num_samples - num_samples = param_dict.get('num_samples') - check_value(num_samples, [-1, INT32_MAX], "num_samples") - # check field_delim field_delim = param_dict.get('field_delim') type_check(field_delim, (str,), 'field delim') diff --git a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc 
b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc index 52aee05be65..28f862ded1a 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc @@ -33,7 +33,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) { // Create a CSVDataset, with single CSV file std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, -1, ShuffleMode::kFalse); + std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -85,7 +85,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) { std::string file1 = datasets_root_path_ + "/testCSV/1.csv"; std::string file2 = datasets_root_path_ + "/testCSV/append.csv"; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({file1, file2}, ',', {}, column_names, -1, ShuffleMode::kGlobal); + std::shared_ptr ds = CSV({file1, file2}, ',', {}, column_names, 0, ShuffleMode::kGlobal); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -179,7 +179,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDistribution) { // Create a CSVDataset, with single CSV file std::string file = datasets_root_path_ + "/testCSV/1.csv"; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({file}, ',', {}, column_names, -1, ShuffleMode::kFalse, 2, 0); + std::shared_ptr ds = CSV({file}, ',', {}, column_names, 0, ShuffleMode::kFalse, 2, 0); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -228,7 +228,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetType) { std::make_shared>(CsvType::STRING, ""), }; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({file}, ',', colum_type, column_names, -1, ShuffleMode::kFalse); + 
std::shared_ptr ds = CSV({file}, ',', colum_type, column_names, 0, ShuffleMode::kFalse); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -343,15 +343,15 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetException) { EXPECT_EQ(ds1, nullptr); - // Test invalid num_samples < -1 - std::shared_ptr ds2 = CSV({file}, ',', {}, column_names, -2); + // Test invalid num_samples < 0 + std::shared_ptr ds2 = CSV({file}, ',', {}, column_names, -1); EXPECT_EQ(ds2, nullptr); // Test invalid num_shards < 1 - std::shared_ptr ds3 = CSV({file}, ',', {}, column_names, -1, ShuffleMode::kFalse, 0); + std::shared_ptr ds3 = CSV({file}, ',', {}, column_names, 0, ShuffleMode::kFalse, 0); EXPECT_EQ(ds3, nullptr); // Test invalid shard_id >= num_shards - std::shared_ptr ds4 = CSV({file}, ',', {}, column_names, -1, ShuffleMode::kFalse, 2, 2); + std::shared_ptr ds4 = CSV({file}, ',', {}, column_names, 0, ShuffleMode::kFalse, 2, 2); EXPECT_EQ(ds4, nullptr); // Test invalid field_delim @@ -373,7 +373,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) { std::string file1 = datasets_root_path_ + "/testCSV/1.csv"; std::string file2 = datasets_root_path_ + "/testCSV/append.csv"; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({file1, file2}, ',', {}, column_names, -1, ShuffleMode::kFiles); + std::shared_ptr ds = CSV({file1, file2}, ',', {}, column_names, 0, ShuffleMode::kFiles); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -432,7 +432,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) { std::string file1 = datasets_root_path_ + "/testCSV/1.csv"; std::string file2 = datasets_root_path_ + "/testCSV/append.csv"; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({file2, file1}, ',', {}, column_names, -1, ShuffleMode::kFiles); + std::shared_ptr ds = CSV({file2, file1}, ',', {}, column_names, 0, ShuffleMode::kFiles); EXPECT_NE(ds, nullptr); // Create an iterator over 
the result of the above dataset @@ -492,7 +492,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) { // Create a CSVFile Dataset, with single CSV file std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::vector column_names = {"col1", "col2", "col3", "col4"}; - std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, -1, ShuffleMode::kGlobal); + std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kGlobal); EXPECT_NE(ds, nullptr); // Create an iterator over the result of the above dataset @@ -540,7 +540,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDuplicateColumnName) { // Create a CSVDataset, with single CSV file std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::vector column_names = {"col1", "col1", "col3", "col4"}; - std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, -1, ShuffleMode::kFalse); + std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse); // Expect failure: duplicate column names EXPECT_EQ(ds, nullptr); }