forked from mindspore-Ecosystem/mindspore
!3651 change num_samples definition
Merge pull request !3651 from jiangzhiwen/dataset/change_num_samples
This commit is contained in:
commit
387dac5832
|
@ -27,7 +27,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
CsvOp::Builder::Builder()
|
CsvOp::Builder::Builder()
|
||||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(-1), builder_shuffle_files_(false) {
|
||||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||||
|
@ -451,7 +451,7 @@ Status CsvOp::operator()() {
|
||||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||||
if (buffer->eoe()) {
|
if (buffer->eoe()) {
|
||||||
workers_done++;
|
workers_done++;
|
||||||
} else if (num_samples_ == 0 || rows_read < num_samples_) {
|
} else if (num_samples_ == -1 || rows_read < num_samples_) {
|
||||||
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
||||||
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
||||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||||
|
|
|
@ -4935,7 +4935,7 @@ class CSVDataset(SourceDataset):
|
||||||
columns as string type.
|
columns as string type.
|
||||||
column_names (list[str], optional): List of column names of the dataset (default=None). If this
|
column_names (list[str], optional): List of column names of the dataset (default=None). If this
|
||||||
is not provided, infers the column_names from the first row of CSV file.
|
is not provided, infers the column_names from the first row of CSV file.
|
||||||
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
|
num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset).
|
||||||
num_parallel_workers (int, optional): number of workers to read the data
|
num_parallel_workers (int, optional): number of workers to read the data
|
||||||
(default=None, number set in the config).
|
(default=None, number set in the config).
|
||||||
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
|
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
|
||||||
|
@ -4959,7 +4959,7 @@ class CSVDataset(SourceDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_csvdataset
|
@check_csvdataset
|
||||||
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
|
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1,
|
||||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
self.dataset_files = self._find_files(dataset_files)
|
self.dataset_files = self._find_files(dataset_files)
|
||||||
|
@ -5010,7 +5010,7 @@ class CSVDataset(SourceDataset):
|
||||||
if self._dataset_size is None:
|
if self._dataset_size is None:
|
||||||
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
|
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
|
||||||
num_rows = get_num_rows(num_rows, self.num_shards)
|
num_rows = get_num_rows(num_rows, self.num_shards)
|
||||||
if self.num_samples is None:
|
if self.num_samples == -1:
|
||||||
return num_rows
|
return num_rows
|
||||||
return min(self.num_samples, num_rows)
|
return min(self.num_samples, num_rows)
|
||||||
return self._dataset_size
|
return self._dataset_size
|
||||||
|
|
|
@ -815,12 +815,16 @@ def check_csvdataset(method):
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id']
|
||||||
|
|
||||||
# check dataset_files; required argument
|
# check dataset_files; required argument
|
||||||
dataset_files = param_dict.get('dataset_files')
|
dataset_files = param_dict.get('dataset_files')
|
||||||
type_check(dataset_files, (str, list), "dataset files")
|
type_check(dataset_files, (str, list), "dataset files")
|
||||||
|
|
||||||
|
# check num_samples
|
||||||
|
num_samples = param_dict.get('num_samples')
|
||||||
|
check_value(num_samples, [-1, INT32_MAX], "num_samples")
|
||||||
|
|
||||||
# check field_delim
|
# check field_delim
|
||||||
field_delim = param_dict.get('field_delim')
|
field_delim = param_dict.get('field_delim')
|
||||||
type_check(field_delim, (str,), 'field delim')
|
type_check(field_delim, (str,), 'field delim')
|
||||||
|
|
Loading…
Reference in New Issue