Fixing api leftover

Fixed compile error

Fixed more testcases
This commit is contained in:
Eric 2021-03-01 10:32:24 -05:00
parent 9493094f9f
commit f9a2379a70
15 changed files with 202 additions and 140 deletions

View File

@ -25,6 +25,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
namespace mindspore {
namespace dataset {

View File

@ -22,6 +22,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
namespace mindspore {
namespace dataset {

View File

@ -138,11 +138,12 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s
// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr || sampler->ValidateParams().IsError()) {
if (sampler == nullptr) {
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(sampler->ValidateParams());
return Status::OK();
}

View File

@ -36,6 +36,7 @@
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"

View File

@ -45,9 +45,7 @@ class TensorRow;
class TensorShape;
class TreeAdapter;
class TreeGetters;
#ifndef ENABLE_ANDROID
class Vocab;
#endif
class DatasetCache;
class DatasetNode;
@ -64,31 +62,23 @@ class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
#ifndef ENABLE_ANDROID
class BucketBatchByLengthDataset;
class FilterDataset;
class CSVDataset;
class TransferDataset;
class ConcatDataset;
class RenameDataset;
#endif
#ifndef ENABLE_ANDROID
class SentencePieceVocab;
enum class SentencePieceModel;
#endif
class DSCallback;
class RepeatDataset;
#ifndef ENABLE_ANDROID
class SkipDataset;
class TakeDataset;
class ZipDataset;
#endif
/// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline.
class Dataset : public std::enable_shared_from_this<Dataset> {
@ -153,7 +143,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
}
#ifndef ENABLE_ANDROID
/// \brief Function to transfer data through a device.
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
/// of data transmission per time is 256M.
@ -186,7 +175,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
}
#endif
/// \brief Function to create a BatchDataset
/// \notes Combines batch_size number of consecutive rows into batches
@ -198,7 +186,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current BatchDataset
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
#ifndef ENABLE_ANDROID
/// \brief Function to create a BucketBatchByLengthDataset
/// \notes Bucket elements according to their lengths. Each bucket will be padded and batched when
/// they are full.
@ -293,7 +280,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
const std::vector<std::string> &input_columns = {}) {
return std::make_shared<FilterDataset>(shared_from_this(), predicate, VectorStringToChar(input_columns));
}
#endif
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
@ -396,7 +382,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
}
#ifndef ENABLE_ANDROID
/// \brief Function to create a Rename Dataset
/// \notes Renames the columns in the input dataset
/// \param[in] input_columns List of the input columns to rename
@ -407,7 +392,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
return std::make_shared<RenameDataset>(shared_from_this(), VectorStringToChar(input_columns),
VectorStringToChar(output_columns));
}
#endif
/// \brief Function to create a RepeatDataset
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
/// \param[in] count Number of times the dataset should be repeated
@ -417,7 +401,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
std::shared_ptr<RepeatDataset> Repeat(int32_t count = -1) {
return std::make_shared<RepeatDataset>(shared_from_this(), count);
}
#ifndef ENABLE_ANDROID
/// \brief Function to create a Shuffle Dataset
/// \notes Randomly shuffles the rows of this dataset
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
@ -449,7 +432,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
all_datasets.push_back(shared_from_this());
return std::make_shared<ZipDataset>(all_datasets);
}
#endif
std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
@ -602,7 +584,6 @@ class BatchDataset : public Dataset {
~BatchDataset() = default;
};
#ifndef ENABLE_ANDROID
class BucketBatchByLengthDataset : public Dataset {
public:
BucketBatchByLengthDataset(
@ -626,7 +607,6 @@ class FilterDataset : public Dataset {
const std::vector<std::vector<char>> &input_columns);
~FilterDataset() = default;
};
#endif
class MapDataset : public Dataset {
public:
@ -643,14 +623,12 @@ class ProjectDataset : public Dataset {
~ProjectDataset() = default;
};
#ifndef ENABLE_ANDROID
class RenameDataset : public Dataset {
public:
RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
const std::vector<std::vector<char>> &output_columns);
~RenameDataset() = default;
};
#endif
class RepeatDataset : public Dataset {
public:
@ -664,7 +642,6 @@ class ShuffleDataset : public Dataset {
~ShuffleDataset() = default;
};
#ifndef ENABLE_ANDROID
class SkipDataset : public Dataset {
public:
SkipDataset(std::shared_ptr<Dataset> input, int32_t count);
@ -682,7 +659,6 @@ class ZipDataset : public Dataset {
explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &inputs);
~ZipDataset() = default;
};
#endif
/// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file
@ -762,7 +738,6 @@ inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const
VectorStringToChar(column_names), decode, sampler, cache);
}
#ifndef ENABLE_ANDROID
class CelebADataset : public Dataset {
public:
explicit CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
@ -1375,7 +1350,6 @@ inline std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string>
return std::make_shared<MindDataDataset>(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded);
}
#endif
class MnistDataset : public Dataset {
public:
@ -1427,7 +1401,6 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
#ifndef ENABLE_ANDROID
/// \brief Function to create a ConcatDataset
/// \notes Reload "+" operator to concat two datasets
@ -1686,7 +1659,6 @@ inline std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint
inline std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
return std::make_shared<ZipDataset>(datasets);
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -608,7 +608,7 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
} else {
// we resize here since the shape changes
// create a new bounding box with the rotate
cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f();
cv::Rect2f bbox = cv::RotatedRect(pc, input_img.size(), degree).boundingRect2f();
rot.at<double>(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0;
rot.at<double>(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0;
// use memcpy and don't compute the new shape since openCV has a rounding problem

View File

@ -3671,7 +3671,7 @@ class ManifestDataset(MappableDataset):
decode (bool, optional): decode the images after reading (default=False).
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None). When this argument is specified, `num_samples` reflects
the max sample number of per shard.
the max number of samples per shard.
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
argument can only be specified when `num_shards` is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.

View File

@ -74,6 +74,7 @@ SET(DE_UT_SRCS
image_process_test.cc
interrupt_test.cc
ir_callback_test.cc
ir_sampler_test.cc
ir_tensor_op_fusion_pass_test.cc
ir_tree_adapter_test.cc
ir_vision_test.cc

View File

@ -57,10 +57,10 @@ TEST_F(MindDataTestPipeline, TestAffineAPI) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
EXPECT_EQ(row["image"].Shape().at(0), 256);
iter->GetNextRow(&row);
// EXPECT_EQ(row["image"].Shape()[0], 256);
}
EXPECT_EQ(i, 15);

View File

@ -88,96 +88,6 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
int64_t num_rows = 30; // dummy variable for number of rows in the dataset
std::shared_ptr<SamplerObj> sampl = std::make_shared<DistributedSamplerObj>(2, 1, false, 6, 1, -1, true);
EXPECT_NE(sampl, nullptr);
std::shared_ptr<SamplerRT> sampler_rt;
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);
sampl = std::make_shared<PKSamplerObj>(3, false, 0);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
sampl = std::make_shared<RandomSamplerObj>(false, 12);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
sampl = std::make_shared<SequentialSamplerObj>(0, 10);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
sampl = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
sampl = std::make_shared<SubsetRandomSamplerObj>(indices, 11);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
// Testing chains
// Parent and child have num_samples
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
EXPECT_NE(sampl1, nullptr);
std::shared_ptr<SamplerRT> sampler_rt1;
sampl1->SamplerBuild(&sampler_rt1);
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SequentialSamplerObj>(0, 10);
EXPECT_NE(sampl2, nullptr);
std::shared_ptr<SamplerRT> sampler_rt2;
sampl2->SamplerBuild(&sampler_rt2);
sampler_rt2->AddChild(sampler_rt1);
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
// Parent doesn't have num_samples
std::shared_ptr<SamplerObj> sampl3 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
EXPECT_NE(sampl3, nullptr);
std::shared_ptr<SamplerRT> sampler_rt3;
sampl3->SamplerBuild(&sampler_rt3);
std::shared_ptr<SamplerObj> sampl4 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
EXPECT_NE(sampl4, nullptr);
std::shared_ptr<SamplerRT> sampler_rt4;
sampl4->SamplerBuild(&sampler_rt4);
sampler_rt4->AddChild(sampler_rt3);
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11);
// Child doesn't have num_samples
std::shared_ptr<SamplerObj> sampl5 = std::make_shared<RandomSamplerObj>(false, 0);
EXPECT_NE(sampl5, nullptr);
std::shared_ptr<SamplerRT> sampler_rt5;
sampl5->SamplerBuild(&sampler_rt5);
std::shared_ptr<SamplerObj> sampl6 = std::make_shared<PKSamplerObj>(3, false, 7);
EXPECT_NE(sampl6, nullptr);
std::shared_ptr<SamplerRT> sampler_rt6;
sampl6->SamplerBuild(&sampler_rt6);
sampler_rt6->AddChild(sampler_rt5);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
}
TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
EXPECT_FALSE(indices.empty());
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
sampl1->SamplerBuild(&sampler_rt);
EXPECT_NE(sampler_rt, nullptr);
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), 0);
EXPECT_TRUE(indices.empty());
std::shared_ptr<SamplerRT> sampler_rt2 = nullptr;
sampl2->SamplerBuild(&sampler_rt2);
EXPECT_NE(sampler_rt, nullptr);
}
TEST_F(MindDataTestPipeline, TestNoSamplerSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNoSamplerSuccess1.";
// Test building a dataset with no sampler provided (defaults to random sampler

View File

@ -37,8 +37,8 @@ TEST_F(MindDataTestPipeline, TestComposeSuccess) {
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
auto decode_op(new vision::Decode());
auto resize_op(new vision::Resize({777, 777}));
std::shared_ptr<TensorTransform> decode_op(new vision::Decode());
std::shared_ptr<TensorTransform> resize_op(new vision::Resize({777, 777}));
transforms::Compose compose({decode_op, resize_op});
// Create a Map operation on ds
@ -493,10 +493,10 @@ TEST_F(MindDataTestPipeline, TestRandomChoiceSuccess) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// auto label = row["label"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
// MS_LOG(INFO) << "Label shape: " << label->shape();
auto image = row["image"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
MS_LOG(INFO) << "Label shape: " << label.Shape();
iter->GetNextRow(&row);
}

View File

@ -56,9 +56,9 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeRandomCropResizeJpegSuccess1) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
// EXPECT_EQ(image->shape()[0] == 500 && image->shape()[1] == 500, true);
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
EXPECT_EQ(image.Shape()[0] == 500 && image.Shape()[1] == 500, true);
iter->GetNextRow(&row);
}
@ -98,9 +98,9 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeRandomCropResizeJpegSuccess2) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
// EXPECT_EQ(image->shape()[0] == 500 && image->shape()[1] == 600, true);
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
EXPECT_EQ(image.Shape()[0] == 500 && image.Shape()[1] == 600, true);
iter->GetNextRow(&row);
}
@ -142,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeResizeJpegSuccess1) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// std::shared_ptr<TensorTransform> image = row["image"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
@ -180,8 +180,8 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeResizeJpegSuccess2) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
// auto image = row["image"];
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
iter->GetNextRow(&row);
}

View File

@ -112,6 +112,8 @@ TEST_F(MindDataTestExecute, TestTransformInput2) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput2.";
// Test Execute with transform op input using API constructors, with std::shared_ptr<TensorTransform pointers,
// instantiated via new
// With this way of creating TensorTransforms, we don't need to explicitly delete the object created with the
// "new" keyword. When the shared pointer goes out of scope the object destructor will be called.
// Read image, construct MSTensor from dataset tensor
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
@ -171,6 +173,7 @@ TEST_F(MindDataTestExecute, TestTransformInput3) {
TEST_F(MindDataTestExecute, TestTransformInputSequential) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInputSequential.";
// Test Execute with transform op input using API constructors, with auto pointers;
// Note that with auto and new, we have to explicitly delete the allocated object as shown below.
// Apply 2 transformations sequentially, including single non-vector Transform op input
// Read image, construct MSTensor from dataset tensor
@ -207,7 +210,7 @@ TEST_F(MindDataTestExecute, TestTransformInputSequential) {
TEST_F(MindDataTestExecute, TestTransformDecodeResizeCenterCrop1) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformDecodeResizeCenterCrop1.";
// Test Execute with Decode, Resize and CenterCrop transform ops input using API constructors, with auto pointers
// Test Execute with Decode, Resize and CenterCrop transform ops input using API constructors, with shared pointers
// Read image, construct MSTensor from dataset tensor
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;

View File

@ -0,0 +1,116 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
#include "minddata/dataset/core/tensor.h"
using namespace mindspore::dataset;
using mindspore::dataset::Tensor;
class MindDataTestIrSampler : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
int64_t num_rows = 30; // dummy variable for number of rows in the dataset
std::shared_ptr<SamplerObj> sampl = std::make_shared<DistributedSamplerObj>(2, 1, false, 6, 1, -1, true);
EXPECT_NE(sampl, nullptr);
std::shared_ptr<SamplerRT> sampler_rt;
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);
sampl = std::make_shared<PKSamplerObj>(3, false, 0);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
sampl = std::make_shared<RandomSamplerObj>(false, 12);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
sampl = std::make_shared<SequentialSamplerObj>(0, 10);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
sampl = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
sampl = std::make_shared<SubsetRandomSamplerObj>(indices, 11);
EXPECT_NE(sampl, nullptr);
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
// Testing chains
// Parent and child have num_samples
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
EXPECT_NE(sampl1, nullptr);
std::shared_ptr<SamplerRT> sampler_rt1;
sampl1->SamplerBuild(&sampler_rt1);
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SequentialSamplerObj>(0, 10);
EXPECT_NE(sampl2, nullptr);
std::shared_ptr<SamplerRT> sampler_rt2;
sampl2->SamplerBuild(&sampler_rt2);
sampler_rt2->AddChild(sampler_rt1);
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
// Parent doesn't have num_samples
std::shared_ptr<SamplerObj> sampl3 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
EXPECT_NE(sampl3, nullptr);
std::shared_ptr<SamplerRT> sampler_rt3;
sampl3->SamplerBuild(&sampler_rt3);
std::shared_ptr<SamplerObj> sampl4 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
EXPECT_NE(sampl4, nullptr);
std::shared_ptr<SamplerRT> sampler_rt4;
sampl4->SamplerBuild(&sampler_rt4);
sampler_rt4->AddChild(sampler_rt3);
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11);
// Child doesn't have num_samples
std::shared_ptr<SamplerObj> sampl5 = std::make_shared<RandomSamplerObj>(false, 0);
EXPECT_NE(sampl5, nullptr);
std::shared_ptr<SamplerRT> sampler_rt5;
sampl5->SamplerBuild(&sampler_rt5);
std::shared_ptr<SamplerObj> sampl6 = std::make_shared<PKSamplerObj>(3, false, 7);
EXPECT_NE(sampl6, nullptr);
std::shared_ptr<SamplerRT> sampler_rt6;
sampl6->SamplerBuild(&sampler_rt6);
sampler_rt6->AddChild(sampler_rt5);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
}
TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
EXPECT_FALSE(indices.empty());
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
sampl1->SamplerBuild(&sampler_rt);
EXPECT_NE(sampler_rt, nullptr);
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), 0);
EXPECT_TRUE(indices.empty());
std::shared_ptr<SamplerRT> sampler_rt2 = nullptr;
sampl2->SamplerBuild(&sampler_rt2);
EXPECT_NE(sampler_rt, nullptr);
}

View File

@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
def test_num_samples():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
num_samples = 1
# sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
data1 = ds.ManifestDataset(
manifest_file, num_samples=num_samples, num_shards=3, shard_id=1
)
row_count = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
row_count += 1
assert row_count == 1
def test_num_samples_tf():
logger.info("test_tfrecord_read_all_dataset")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
files = ["../data/dataset/testTFTestAllTypes/test.data"]
# here num samples indicate the rows per shard. Total rows in file = 12
ds1 = ds.TFRecordDataset(files, schema_file, num_samples=2)
count = 0
for _ in ds1.create_tuple_iterator(num_epochs=1):
count += 1
assert count == 2
def test_num_samples_image_folder():
data_dir = "../data/dataset/testPK/data"
ds1 = ds.ImageFolderDataset(data_dir, num_samples=2, num_shards=2, shard_id=0)
count = 0
for _ in ds1.create_tuple_iterator(num_epochs=1):
count += 1
assert count == 2
if __name__ == "__main__":
test_num_samples()
test_num_samples_tf()
test_num_samples_image_folder()