forked from mindspore-Ecosystem/mindspore
add GetDatasetSize to concat node
This commit is contained in:
parent
ddf84551ab
commit
b5e91139e2
|
@ -402,7 +402,7 @@ std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::s
|
|||
// Function to overload "+" operator to concat two datasets
|
||||
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
|
||||
const std::shared_ptr<Dataset> &datasets2) {
|
||||
return std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
|
||||
return std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
|
||||
}
|
||||
|
||||
// Function to create a TextFileDataset.
|
||||
|
|
|
@ -73,6 +73,51 @@ Status ConcatNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// calculate the total size of all nodes
|
||||
int64_t total_dataset_size = 0;
|
||||
int64_t child_dataset_size = 0;
|
||||
for (int idx = 0; idx < children_.size(); idx++) {
|
||||
if (children_flag_and_nums_.empty() || children_flag_and_nums_[idx].second == 0) {
|
||||
children_[idx]->GetDatasetSize(size_getter, false, &child_dataset_size);
|
||||
total_dataset_size += child_dataset_size;
|
||||
} else {
|
||||
total_dataset_size += children_flag_and_nums_[idx].second;
|
||||
}
|
||||
}
|
||||
|
||||
// calculate the size of the shard
|
||||
int64_t shard_dataset_size = 0;
|
||||
if (sampler_ != nullptr) {
|
||||
std::shared_ptr<DistributedSamplerRT> sampler_rt =
|
||||
std::static_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild());
|
||||
sampler_rt->SetNumRowsInDataset(total_dataset_size);
|
||||
sampler_rt->InitSampler();
|
||||
|
||||
// (total_size % num_shards != 0) & shard_id >= (remainder) ? CalculateNumSamples()-1 : CalculateNumSamples()
|
||||
// example: 23 rows, 10 shards --> shard sizes = {3,3,3,2,2,2,2,2,2,2}
|
||||
if ((sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum()) > 0 &&
|
||||
sampler_rt->GetDeviceID() >= (sampler_rt->GetNumSamples() % sampler_rt->GetDeviceNum())) {
|
||||
shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples()) - 1;
|
||||
} else {
|
||||
shard_dataset_size = sampler_rt->CalculateNumSamples(sampler_rt->GetNumSamples());
|
||||
}
|
||||
} else {
|
||||
shard_dataset_size = total_dataset_size;
|
||||
}
|
||||
|
||||
*dataset_size = shard_dataset_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
|
||||
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
|
|
|
@ -55,6 +55,15 @@ class ConcatNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if build successfully
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size the size of the dataset
|
||||
/// \return Status of the function
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
|
|
@ -261,8 +261,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \param[in] datasets List of shared pointers to the dataset that should be concatenated together
|
||||
/// \return Shared pointer to the current ConcatDataset
|
||||
std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||
std::vector<std::shared_ptr<Dataset>> all_datasets = datasets;
|
||||
all_datasets.push_back(shared_from_this());
|
||||
std::vector<std::shared_ptr<Dataset>> all_datasets{shared_from_this()};
|
||||
all_datasets.insert(std::end(all_datasets), std::begin(datasets), std::end(datasets));
|
||||
return std::make_shared<ConcatDataset>(all_datasets);
|
||||
}
|
||||
|
||||
|
|
|
@ -372,7 +372,7 @@ class Dataset:
|
|||
Args:
|
||||
condition_name (str): The condition name that is used to toggle sending next row.
|
||||
num_batch (int): the number of batches without blocking at the start of each epoch.
|
||||
callback (function): The callback funciton that will be invoked when sync_update is called.
|
||||
callback (function): The callback function that will be invoked when sync_update is called.
|
||||
|
||||
Returns:
|
||||
SyncWaitDataset, dataset added a blocking condition.
|
||||
|
@ -398,7 +398,7 @@ class Dataset:
|
|||
|
||||
1. Make a shuffle buffer that contains the first buffer_size rows.
|
||||
2. Randomly select an element from the shuffle buffer to be the next row
|
||||
propogated to the child node.
|
||||
propagated to the child node.
|
||||
3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
|
||||
4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.
|
||||
|
||||
|
@ -1649,8 +1649,7 @@ class MappableDataset(SourceDataset):
|
|||
def add_sampler(self, new_sampler):
|
||||
# note: By adding a sampler, the sampled IDs will flow to new_sampler
|
||||
# after first passing through the current samplers attached to this dataset.
|
||||
if self.dataset_size is not None:
|
||||
self.dataset_size = None
|
||||
self.dataset_size = None
|
||||
new_sampler.add_child(self.sampler)
|
||||
self.sampler = new_sampler
|
||||
|
||||
|
@ -1676,8 +1675,7 @@ class MappableDataset(SourceDataset):
|
|||
raise TypeError("Input sampler can not be None.")
|
||||
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
||||
raise TypeError("Input sampler is not an instance of a sampler.")
|
||||
if self.dataset_size is not None:
|
||||
self.dataset_size = None
|
||||
self.dataset_size = None
|
||||
|
||||
self.sampler = self.sampler.child_sampler
|
||||
self.add_sampler(new_sampler)
|
||||
|
@ -1718,7 +1716,7 @@ class MappableDataset(SourceDataset):
|
|||
- The sum of split sizes < K, the difference will be added to the first split.
|
||||
|
||||
- The sum of split sizes > K, the difference will be removed from the first large
|
||||
enough split such that it will have atleast 1 row after removing the difference.
|
||||
enough split such that it will have at least 1 row after removing the difference.
|
||||
|
||||
randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
|
||||
If True, the data will be randomly split. Otherwise, each split will be created with
|
||||
|
@ -2647,6 +2645,8 @@ class ConcatDataset(Dataset):
|
|||
if sampler.get_num_samples() is not None:
|
||||
raise ValueError("The parameter num_samples of DistributedSampler is not support to be set!")
|
||||
|
||||
self.dataset_size = None
|
||||
|
||||
self._sampler = _select_sampler(None, sampler, None, None, None)
|
||||
cumulative_samples_nums = 0
|
||||
for index, child in enumerate(self.children):
|
||||
|
|
|
@ -53,7 +53,7 @@ TEST_F(MindDataTestTreeModifying, AppendChild) {
|
|||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds6 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds3 = ds1->Take(10);
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds4->IRNode();
|
||||
|
@ -110,7 +110,7 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt01) {
|
|||
std::shared_ptr<Dataset> ds3 = ds1->Take(10);
|
||||
std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ds5->Repeat(4);
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
Status rc;
|
||||
std::shared_ptr<DatasetNode> root = ds4->IRNode();
|
||||
auto ir_tree = std::make_shared<TreeAdapter>();
|
||||
|
@ -173,7 +173,7 @@ TEST_F(MindDataTestTreeModifying, InsertChildAt04) {
|
|||
std::shared_ptr<Dataset> ds3 = ds1->Take(10);
|
||||
std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ds5->Repeat(4);
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
Status rc;
|
||||
std::shared_ptr<DatasetNode> root = ds4->IRNode();
|
||||
auto ir_tree = std::make_shared<TreeAdapter>();
|
||||
|
@ -253,7 +253,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove01) {
|
|||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds3 = ds1->Take(10);
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds4->IRNode();
|
||||
|
@ -280,7 +280,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove02) {
|
|||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds3 = ds1->Take(10);
|
||||
std::shared_ptr<Dataset> ds4 = ds2 + ds3; // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3 + ds2;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds4->IRNode();
|
||||
|
@ -307,7 +307,7 @@ TEST_F(MindDataTestTreeModifying, InsertAbove03) {
|
|||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds3 = ds1->Take(10);
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds4->IRNode();
|
||||
|
@ -372,9 +372,9 @@ TEST_F(MindDataTestTreeModifying, Drop01) {
|
|||
std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
|
||||
std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
std::shared_ptr<Dataset> ds6 = ds4->Take(13);
|
||||
std::shared_ptr<Dataset> ds10 = ds6 + ds9;
|
||||
std::shared_ptr<Dataset> ds10 = ds9 + ds6;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds10->IRNode();
|
||||
|
@ -437,9 +437,9 @@ TEST_F(MindDataTestTreeModifying, Drop03) {
|
|||
std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
|
||||
std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
std::shared_ptr<Dataset> ds6 = ds4->Take(13);
|
||||
std::shared_ptr<Dataset> ds10 = ds6 + ds9;
|
||||
std::shared_ptr<Dataset> ds10 = ds9 + ds6;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds10->IRNode();
|
||||
|
@ -487,11 +487,11 @@ TEST_F(MindDataTestTreeModifying, Drop04) {
|
|||
std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
|
||||
std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4}); // ds1 is put after (ds5, ds4)!!!
|
||||
std::shared_ptr<Dataset> ds10 = ds6 + ds9;
|
||||
std::shared_ptr<Dataset> ds6 = ds5->Concat({ds4, ds1});
|
||||
std::shared_ptr<Dataset> ds10 = ds9 + ds6;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds10->IRNode();
|
||||
|
@ -548,8 +548,8 @@ TEST_F(MindDataTestTreeModifying, Drop05) {
|
|||
std::shared_ptr<Dataset> ds4 = ds3->Skip(1);
|
||||
std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4}); // ds1 is put after (ds5, ds4)!!!
|
||||
std::shared_ptr<Dataset> ds10 = ds6 + ds9;
|
||||
std::shared_ptr<Dataset> ds6 = ds5->Concat({ds4, ds1});
|
||||
std::shared_ptr<Dataset> ds10 = ds9 + ds6;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds10->IRNode();
|
||||
|
@ -603,11 +603,11 @@ TEST_F(MindDataTestTreeModifying, Drop06) {
|
|||
std::shared_ptr<Dataset> ds9 = ds8->Skip(1);
|
||||
std::shared_ptr<Dataset> ds3 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds4 = ds2->Concat({ds3}); // ds2 is the second child and ds3 is the first child!!!
|
||||
std::shared_ptr<Dataset> ds4 = ds3->Concat({ds2});
|
||||
std::shared_ptr<Dataset> ds5 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
std::shared_ptr<Dataset> ds6 = ds1->Concat({ds5, ds4}); // ds1 is put after (ds5, ds4)!!!
|
||||
std::shared_ptr<Dataset> ds10 = ds6 + ds9;
|
||||
std::shared_ptr<Dataset> ds6 = ds5->Concat({ds4, ds1}); // ds1 is put after (ds5, ds4)!!!
|
||||
std::shared_ptr<Dataset> ds10 = ds9 + ds6;
|
||||
Status rc;
|
||||
|
||||
std::shared_ptr<DatasetNode> root = ds10->IRNode();
|
||||
|
|
|
@ -33,12 +33,19 @@ def generator_10():
|
|||
for i in range(3, 10):
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
|
||||
def generator_20():
|
||||
for i in range(10, 20):
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
# In generator_29 dataset: Number of rows is 9; its values are 20, 21, 22 ... 28
|
||||
def generator_29():
|
||||
for i in range(20, 29):
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
def test_concat_01():
|
||||
"""
|
||||
Test concat: test concat 2 datasets that have the same column name and data type
|
||||
|
@ -316,7 +323,7 @@ def test_concat_13():
|
|||
|
||||
def test_concat_14():
|
||||
"""
|
||||
Test concat: create dataset with different dataset folder, and do diffrent operation then concat
|
||||
Test concat: Testing concat on two different source datasets with different dataset operations.
|
||||
"""
|
||||
logger.info("test_concat_14")
|
||||
DATA_DIR = "../data/dataset/testPK/data"
|
||||
|
@ -365,6 +372,63 @@ def test_concat_15():
|
|||
assert sum([1 for _ in data3]) == 47
|
||||
|
||||
|
||||
def test_concat_16():
|
||||
"""
|
||||
Test concat: test get_dataset_size on nested concats
|
||||
"""
|
||||
logger.info("test_concat_16")
|
||||
DATA_DIR = "../data/dataset/testPK/data"
|
||||
DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
|
||||
data1 = ds.ImageFolderDataset(DATA_DIR)
|
||||
data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])
|
||||
|
||||
data3 = ds.GeneratorDataset(generator, ["col1"])
|
||||
data4 = ds.GeneratorDataset(generator_10, ["col1"])
|
||||
|
||||
data5 = data1 + data2
|
||||
data6 = data3 + data4
|
||||
data7 = data5 + data6
|
||||
|
||||
ds.config.set_seed(1)
|
||||
|
||||
# 57 is the total size of all 4 leaf datasets
|
||||
assert data7.get_dataset_size() == 57
|
||||
|
||||
|
||||
def test_concat_17():
|
||||
"""
|
||||
Test concat: test get_dataset_size on nested concats (with sampler)
|
||||
"""
|
||||
logger.info("test_concat_17")
|
||||
|
||||
data1 = ds.GeneratorDataset(generator, ["col1"])
|
||||
data2 = ds.GeneratorDataset(generator_10, ["col1"])
|
||||
|
||||
data3 = ds.GeneratorDataset(generator_20, ["col1"])
|
||||
data4 = ds.GeneratorDataset(generator_29, ["col1"])
|
||||
|
||||
data5 = data1 + data2
|
||||
data6 = data3 + data4
|
||||
data7 = data5 + data6
|
||||
|
||||
ds.config.set_seed(1)
|
||||
shard_num = 10
|
||||
counter = 0
|
||||
|
||||
for i in range(shard_num):
|
||||
distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
|
||||
data7.use_sampler(distributed_sampler)
|
||||
iter_counter = 0
|
||||
for _ in data7.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
counter += 1
|
||||
iter_counter += 1
|
||||
assert data7.get_dataset_size() == iter_counter
|
||||
|
||||
# 29 is the total size of all 4 leaf datasets
|
||||
assert counter == 29
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_concat_01()
|
||||
test_concat_02()
|
||||
|
@ -381,3 +445,5 @@ if __name__ == "__main__":
|
|||
test_concat_13()
|
||||
test_concat_14()
|
||||
test_concat_15()
|
||||
test_concat_16()
|
||||
test_concat_17()
|
||||
|
|
Loading…
Reference in New Issue