From f9a2379a70e0ee926b8ef623de543fcad15f7175 Mon Sep 17 00:00:00 2001 From: Eric Date: Mon, 1 Mar 2021 10:32:24 -0500 Subject: [PATCH] Fixing api leftover Fixed compile error Fixed more testcases --- .../engine/ir/cache/dataset_cache_impl.h | 1 + .../engine/ir/cache/pre_built_dataset_cache.h | 1 + .../engine/ir/datasetops/dataset_node.cc | 3 +- .../engine/ir/datasetops/dataset_node.h | 1 + .../ccsrc/minddata/dataset/include/datasets.h | 28 ----- .../dataset/kernels/image/image_utils.cc | 2 +- mindspore/dataset/engine/datasets.py | 2 +- tests/ut/cpp/dataset/CMakeLists.txt | 1 + tests/ut/cpp/dataset/c_api_affine_test.cc | 6 +- tests/ut/cpp/dataset/c_api_samplers_test.cc | 90 -------------- tests/ut/cpp/dataset/c_api_transforms_test.cc | 12 +- .../dataset/c_api_vision_soft_dvpp_test.cc | 18 +-- tests/ut/cpp/dataset/execute_test.cc | 5 +- tests/ut/cpp/dataset/ir_sampler_test.cc | 116 ++++++++++++++++++ tests/ut/python/dataset/test_num_samples.py | 56 +++++++++ 15 files changed, 202 insertions(+), 140 deletions(-) create mode 100644 tests/ut/cpp/dataset/ir_sampler_test.cc create mode 100644 tests/ut/python/dataset/test_num_samples.py diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h index 6843eb3183e..f4f7c7c2450 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h @@ -25,6 +25,7 @@ #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/ir/cache/dataset_cache.h" +#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h index 
d1588ce19c7..83faa7e37c7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h @@ -22,6 +22,7 @@ #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/ir/cache/dataset_cache.h" +#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index cf370714ac5..2a1fb556ac8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -138,11 +138,12 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s // Helper function to validate dataset sampler parameter Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr &sampler) { - if (sampler == nullptr || sampler->ValidateParams().IsError()) { + if (sampler == nullptr) { std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + RETURN_IF_NOT_OK(sampler->ValidateParams()); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index e95251637bf..0695335288a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -36,6 +36,7 @@ #include "minddata/dataset/engine/datasetops/skip_op.h" #include "minddata/dataset/engine/datasetops/take_op.h" #include "minddata/dataset/engine/ir/cache/dataset_cache.h" +#include 
"minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" #include "minddata/dataset/include/datasets.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/status.h" diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index c565a89cf76..58720314c98 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -45,9 +45,7 @@ class TensorRow; class TensorShape; class TreeAdapter; class TreeGetters; -#ifndef ENABLE_ANDROID class Vocab; -#endif class DatasetCache; class DatasetNode; @@ -64,31 +62,23 @@ class BatchDataset; class MapDataset; class ProjectDataset; class ShuffleDataset; -#ifndef ENABLE_ANDROID class BucketBatchByLengthDataset; class FilterDataset; class CSVDataset; class TransferDataset; class ConcatDataset; class RenameDataset; -#endif -#ifndef ENABLE_ANDROID class SentencePieceVocab; enum class SentencePieceModel; -#endif class DSCallback; class RepeatDataset; - -#ifndef ENABLE_ANDROID class SkipDataset; class TakeDataset; class ZipDataset; -#endif - /// \class Dataset datasets.h /// \brief A base class to represent a dataset in the data pipeline. class Dataset : public std::enable_shared_from_this { @@ -153,7 +143,6 @@ class Dataset : public std::enable_shared_from_this { return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs); } -#ifndef ENABLE_ANDROID /// \brief Function to transfer data through a device. /// \notes If device is Ascend, features of data will be transferred one by one. The limitation /// of data transmission per time is 256M. 
@@ -186,7 +175,6 @@ class Dataset : public std::enable_shared_from_this { bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") { return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type)); } -#endif /// \brief Function to create a BatchDataset /// \notes Combines batch_size number of consecutive rows into batches @@ -198,7 +186,6 @@ class Dataset : public std::enable_shared_from_this { /// \return Shared pointer to the current BatchDataset std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); -#ifndef ENABLE_ANDROID /// \brief Function to create a BucketBatchByLengthDataset /// \notes Bucket elements according to their lengths. Each bucket will be padded and batched when /// they are full. @@ -293,7 +280,6 @@ class Dataset : public std::enable_shared_from_this { const std::vector &input_columns = {}) { return std::make_shared(shared_from_this(), predicate, VectorStringToChar(input_columns)); } -#endif /// \brief Function to create a MapDataset /// \notes Applies each operation in operations to this dataset @@ -396,7 +382,6 @@ class Dataset : public std::enable_shared_from_this { return std::make_shared(shared_from_this(), VectorStringToChar(columns)); } -#ifndef ENABLE_ANDROID /// \brief Function to create a Rename Dataset /// \notes Renames the columns in the input dataset /// \param[in] input_columns List of the input columns to rename @@ -407,7 +392,6 @@ class Dataset : public std::enable_shared_from_this { return std::make_shared(shared_from_this(), VectorStringToChar(input_columns), VectorStringToChar(output_columns)); } -#endif /// \brief Function to create a RepeatDataset /// \notes Repeats this dataset count times. 
Repeat indefinitely if count is -1 /// \param[in] count Number of times the dataset should be repeated @@ -417,7 +401,6 @@ class Dataset : public std::enable_shared_from_this { std::shared_ptr Repeat(int32_t count = -1) { return std::make_shared(shared_from_this(), count); } -#ifndef ENABLE_ANDROID /// \brief Function to create a Shuffle Dataset /// \notes Randomly shuffles the rows of this dataset /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling @@ -449,7 +432,6 @@ class Dataset : public std::enable_shared_from_this { all_datasets.push_back(shared_from_this()); return std::make_shared(all_datasets); } -#endif std::shared_ptr IRNode() { return ir_node_; } @@ -602,7 +584,6 @@ class BatchDataset : public Dataset { ~BatchDataset() = default; }; -#ifndef ENABLE_ANDROID class BucketBatchByLengthDataset : public Dataset { public: BucketBatchByLengthDataset( @@ -626,7 +607,6 @@ class FilterDataset : public Dataset { const std::vector> &input_columns); ~FilterDataset() = default; }; -#endif class MapDataset : public Dataset { public: @@ -643,14 +623,12 @@ class ProjectDataset : public Dataset { ~ProjectDataset() = default; }; -#ifndef ENABLE_ANDROID class RenameDataset : public Dataset { public: RenameDataset(std::shared_ptr input, const std::vector> &input_columns, const std::vector> &output_columns); ~RenameDataset() = default; }; -#endif class RepeatDataset : public Dataset { public: @@ -664,7 +642,6 @@ class ShuffleDataset : public Dataset { ~ShuffleDataset() = default; }; -#ifndef ENABLE_ANDROID class SkipDataset : public Dataset { public: SkipDataset(std::shared_ptr input, int32_t count); @@ -682,7 +659,6 @@ class ZipDataset : public Dataset { explicit ZipDataset(const std::vector> &inputs); ~ZipDataset() = default; }; -#endif /// \brief Function to create a SchemaObj /// \param[in] schema_file Path of schema file @@ -762,7 +738,6 @@ inline std::shared_ptr Album(const std::string &dataset_dir, const 
VectorStringToChar(column_names), decode, sampler, cache); } -#ifndef ENABLE_ANDROID class CelebADataset : public Dataset { public: explicit CelebADataset(const std::vector &dataset_dir, const std::vector &usage, @@ -1375,7 +1350,6 @@ inline std::shared_ptr MindData(const std::vector return std::make_shared(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler, padded_sample, num_padded); } -#endif class MnistDataset : public Dataset { public: @@ -1427,7 +1401,6 @@ inline std::shared_ptr Mnist(const std::string &dataset_dir, const const std::shared_ptr &cache = nullptr) { return std::make_shared(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); } -#ifndef ENABLE_ANDROID /// \brief Function to create a ConcatDataset /// \notes Reload "+" operator to concat two datasets @@ -1686,7 +1659,6 @@ inline std::shared_ptr CreateDatasetCache(session_id_type id, uint inline std::shared_ptr Zip(const std::vector> &datasets) { return std::make_shared(datasets); } -#endif } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index e86a8c0fd26..6085c5fa0c9 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -608,7 +608,7 @@ Status Rotate(const std::shared_ptr &input, std::shared_ptr *out } else { // we resize here since the shape changes // create a new bounding box with the rotate - cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f(); + cv::Rect2f bbox = cv::RotatedRect(pc, input_img.size(), degree).boundingRect2f(); rot.at(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0; rot.at(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0; // use memcpy and don't compute the new shape since openCV has a rounding problem diff --git a/mindspore/dataset/engine/datasets.py 
b/mindspore/dataset/engine/datasets.py index 292de85c0f5..8729b190884 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3671,7 +3671,7 @@ class ManifestDataset(MappableDataset): decode (bool, optional): decode the images after reading (default=False). num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). When this argument is specified, `num_samples` reflects - the max sample number of per shard. + the max number of samples per shard. shard_id (int, optional): The shard ID within `num_shards` (default=None). This argument can only be specified when `num_shards` is also specified. cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index db2ecc8b2d9..f046bbe0faa 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -74,6 +74,7 @@ SET(DE_UT_SRCS image_process_test.cc interrupt_test.cc ir_callback_test.cc + ir_sampler_test.cc ir_tensor_op_fusion_pass_test.cc ir_tree_adapter_test.cc ir_vision_test.cc diff --git a/tests/ut/cpp/dataset/c_api_affine_test.cc b/tests/ut/cpp/dataset/c_api_affine_test.cc index e8f6147176e..9b572e5513b 100644 --- a/tests/ut/cpp/dataset/c_api_affine_test.cc +++ b/tests/ut/cpp/dataset/c_api_affine_test.cc @@ -57,10 +57,10 @@ TEST_F(MindDataTestPipeline, TestAffineAPI) { uint64_t i = 0; while (row.size() != 0) { i++; - // auto image = row["image"]; - // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + EXPECT_EQ(row["image"].Shape().at(0), 256); iter->GetNextRow(&row); - // EXPECT_EQ(row["image"].Shape()[0], 256); } EXPECT_EQ(i, 15); diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc index 2faa3b7b175..ce20a9bad61 100644 --- 
a/tests/ut/cpp/dataset/c_api_samplers_test.cc +++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc @@ -88,96 +88,6 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestCalculateNumSamples) { - int64_t num_rows = 30; // dummy variable for number of rows in the dataset - std::shared_ptr sampl = std::make_shared(2, 1, false, 6, 1, -1, true); - EXPECT_NE(sampl, nullptr); - std::shared_ptr sampler_rt; - sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); - - sampl = std::make_shared(3, false, 0); - EXPECT_NE(sampl, nullptr); - sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); - - sampl = std::make_shared(false, 12); - EXPECT_NE(sampl, nullptr); - sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); - - sampl = std::make_shared(0, 10); - EXPECT_NE(sampl, nullptr); - sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); - - std::vector weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; - sampl = std::make_shared(weights, 12); - EXPECT_NE(sampl, nullptr); - sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); - - std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; - sampl = std::make_shared(indices, 11); - EXPECT_NE(sampl, nullptr); - sampl->SamplerBuild(&sampler_rt); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); - - // Testing chains - // Parent and child have num_samples - std::shared_ptr sampl1 = std::make_shared(weights, 12); - EXPECT_NE(sampl1, nullptr); - std::shared_ptr sampler_rt1; - sampl1->SamplerBuild(&sampler_rt1); - - std::shared_ptr sampl2 = std::make_shared(0, 10); - EXPECT_NE(sampl2, nullptr); - std::shared_ptr sampler_rt2; - sampl2->SamplerBuild(&sampler_rt2); - sampler_rt2->AddChild(sampler_rt1); - 
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); - - // Parent doesn't have num_samples - std::shared_ptr sampl3 = std::make_shared(weights, 12); - EXPECT_NE(sampl3, nullptr); - std::shared_ptr sampler_rt3; - sampl3->SamplerBuild(&sampler_rt3); - - std::shared_ptr sampl4 = std::make_shared(indices, 0); - EXPECT_NE(sampl4, nullptr); - std::shared_ptr sampler_rt4; - sampl4->SamplerBuild(&sampler_rt4); - sampler_rt4->AddChild(sampler_rt3); - EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11); - - // Child doesn't have num_samples - std::shared_ptr sampl5 = std::make_shared(false, 0); - EXPECT_NE(sampl5, nullptr); - std::shared_ptr sampler_rt5; - sampl5->SamplerBuild(&sampler_rt5); - - std::shared_ptr sampl6 = std::make_shared(3, false, 7); - EXPECT_NE(sampl6, nullptr); - std::shared_ptr sampler_rt6; - sampl6->SamplerBuild(&sampler_rt6); - sampler_rt6->AddChild(sampler_rt5); - EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); -} - -TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { - std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; - std::shared_ptr sampl1 = std::make_shared(indices, 0); - EXPECT_FALSE(indices.empty()); - std::shared_ptr sampler_rt = nullptr; - sampl1->SamplerBuild(&sampler_rt); - EXPECT_NE(sampler_rt, nullptr); - std::shared_ptr sampl2 = std::make_shared(std::move(indices), 0); - EXPECT_TRUE(indices.empty()); - std::shared_ptr sampler_rt2 = nullptr; - sampl2->SamplerBuild(&sampler_rt2); - EXPECT_NE(sampler_rt, nullptr); -} - TEST_F(MindDataTestPipeline, TestNoSamplerSuccess1) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNoSamplerSuccess1."; // Test building a dataset with no sampler provided (defaults to random sampler diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 193df824ea0..e62e6ac33ec 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -37,8 +37,8 @@ 
TEST_F(MindDataTestPipeline, TestComposeSuccess) { EXPECT_NE(ds, nullptr); // Create objects for the tensor ops - auto decode_op(new vision::Decode()); - auto resize_op(new vision::Resize({777, 777})); + std::shared_ptr decode_op(new vision::Decode()); + std::shared_ptr resize_op(new vision::Resize({777, 777})); transforms::Compose compose({decode_op, resize_op}); // Create a Map operation on ds @@ -493,10 +493,10 @@ TEST_F(MindDataTestPipeline, TestRandomChoiceSuccess) { uint64_t i = 0; while (row.size() != 0) { i++; - // auto image = row["image"]; - // auto label = row["label"]; - // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - // MS_LOG(INFO) << "Label shape: " << label->shape(); + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + MS_LOG(INFO) << "Label shape: " << label.Shape(); iter->GetNextRow(&row); } diff --git a/tests/ut/cpp/dataset/c_api_vision_soft_dvpp_test.cc b/tests/ut/cpp/dataset/c_api_vision_soft_dvpp_test.cc index 3aecad80850..795e9e3a6a7 100644 --- a/tests/ut/cpp/dataset/c_api_vision_soft_dvpp_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_soft_dvpp_test.cc @@ -56,9 +56,9 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeRandomCropResizeJpegSuccess1) { uint64_t i = 0; while (row.size() != 0) { i++; - // auto image = row["image"]; - // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - // EXPECT_EQ(image->shape()[0] == 500 && image->shape()[1] == 500, true); + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + EXPECT_EQ(image.Shape()[0] == 500 && image.Shape()[1] == 500, true); iter->GetNextRow(&row); } @@ -98,9 +98,9 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeRandomCropResizeJpegSuccess2) { uint64_t i = 0; while (row.size() != 0) { i++; - // auto image = row["image"]; - // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - // EXPECT_EQ(image->shape()[0] == 500 && image->shape()[1] == 600, true); + auto 
image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + EXPECT_EQ(image.Shape()[0] == 500 && image.Shape()[1] == 600, true); iter->GetNextRow(&row); } @@ -142,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeResizeJpegSuccess1) { uint64_t i = 0; while (row.size() != 0) { i++; - // auto image = row["image"]; + // std::shared_ptr image = row["image"]; // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); iter->GetNextRow(&row); } @@ -180,8 +180,8 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeResizeJpegSuccess2) { uint64_t i = 0; while (row.size() != 0) { i++; - // auto image = row["image"]; - // MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); iter->GetNextRow(&row); } diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc index 4d138d12df0..9bfffa4072a 100644 --- a/tests/ut/cpp/dataset/execute_test.cc +++ b/tests/ut/cpp/dataset/execute_test.cc @@ -112,6 +112,8 @@ TEST_F(MindDataTestExecute, TestTransformInput2) { MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput2."; // Test Execute with transform op input using API constructors, with std::shared_ptr de_tensor; @@ -171,6 +173,7 @@ TEST_F(MindDataTestExecute, TestTransformInput3) { TEST_F(MindDataTestExecute, TestTransformInputSequential) { MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInputSequential."; // Test Execute with transform op input using API constructors, with auto pointers; + // Note that with auto and new, we have to explicitly delete the allocated object as shown below. 
// Apply 2 transformations sequentially, including single non-vector Transform op input // Read image, construct MSTensor from dataset tensor @@ -207,7 +210,7 @@ TEST_F(MindDataTestExecute, TestTransformInputSequential) { TEST_F(MindDataTestExecute, TestTransformDecodeResizeCenterCrop1) { MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformDecodeResizeCenterCrop1."; - // Test Execute with Decode, Resize and CenterCrop transform ops input using API constructors, with auto pointers + // Test Execute with Decode, Resize and CenterCrop transform ops input using API constructors, with shared pointers // Read image, construct MSTensor from dataset tensor std::shared_ptr de_tensor; diff --git a/tests/ut/cpp/dataset/ir_sampler_test.cc b/tests/ut/cpp/dataset/ir_sampler_test.cc new file mode 100644 index 00000000000..42f2e4359e4 --- /dev/null +++ b/tests/ut/cpp/dataset/ir_sampler_test.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "common/common.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" +#include "minddata/dataset/core/tensor.h" + +using namespace mindspore::dataset; +using mindspore::dataset::Tensor; + +class MindDataTestIrSampler : public UT::DatasetOpTesting { + protected: +}; + +TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) { + int64_t num_rows = 30; // dummy variable for number of rows in the dataset + std::shared_ptr sampl = std::make_shared(2, 1, false, 6, 1, -1, true); + EXPECT_NE(sampl, nullptr); + std::shared_ptr sampler_rt; + sampl->SamplerBuild(&sampler_rt); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); + + sampl = std::make_shared(3, false, 0); + EXPECT_NE(sampl, nullptr); + sampl->SamplerBuild(&sampler_rt); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); + + sampl = std::make_shared(false, 12); + EXPECT_NE(sampl, nullptr); + sampl->SamplerBuild(&sampler_rt); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); + + sampl = std::make_shared(0, 10); + EXPECT_NE(sampl, nullptr); + sampl->SamplerBuild(&sampler_rt); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); + + std::vector weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; + sampl = std::make_shared(weights, 12); + EXPECT_NE(sampl, nullptr); + sampl->SamplerBuild(&sampler_rt); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); + + std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; + sampl = std::make_shared(indices, 11); + EXPECT_NE(sampl, nullptr); + sampl->SamplerBuild(&sampler_rt); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); + + // Testing chains + // Parent and child have num_samples + std::shared_ptr sampl1 = std::make_shared(weights, 12); + EXPECT_NE(sampl1, nullptr); + std::shared_ptr sampler_rt1; + sampl1->SamplerBuild(&sampler_rt1); + + std::shared_ptr sampl2 = std::make_shared(0, 
10); + EXPECT_NE(sampl2, nullptr); + std::shared_ptr sampler_rt2; + sampl2->SamplerBuild(&sampler_rt2); + sampler_rt2->AddChild(sampler_rt1); + EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); + + // Parent doesn't have num_samples + std::shared_ptr sampl3 = std::make_shared(weights, 12); + EXPECT_NE(sampl3, nullptr); + std::shared_ptr sampler_rt3; + sampl3->SamplerBuild(&sampler_rt3); + + std::shared_ptr sampl4 = std::make_shared(indices, 0); + EXPECT_NE(sampl4, nullptr); + std::shared_ptr sampler_rt4; + sampl4->SamplerBuild(&sampler_rt4); + sampler_rt4->AddChild(sampler_rt3); + EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11); + + // Child doesn't have num_samples + std::shared_ptr sampl5 = std::make_shared(false, 0); + EXPECT_NE(sampl5, nullptr); + std::shared_ptr sampler_rt5; + sampl5->SamplerBuild(&sampler_rt5); + + std::shared_ptr sampl6 = std::make_shared(3, false, 7); + EXPECT_NE(sampl6, nullptr); + std::shared_ptr sampler_rt6; + sampl6->SamplerBuild(&sampler_rt6); + sampler_rt6->AddChild(sampler_rt5); + EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); +} + +TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) { + std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; + std::shared_ptr sampl1 = std::make_shared(indices, 0); + EXPECT_FALSE(indices.empty()); + std::shared_ptr sampler_rt = nullptr; + sampl1->SamplerBuild(&sampler_rt); + EXPECT_NE(sampler_rt, nullptr); + std::shared_ptr sampl2 = std::make_shared(std::move(indices), 0); + EXPECT_TRUE(indices.empty()); + std::shared_ptr sampler_rt2 = nullptr; + sampl2->SamplerBuild(&sampler_rt2); + EXPECT_NE(sampler_rt2, nullptr); +} diff --git a/tests/ut/python/dataset/test_num_samples.py b/tests/ut/python/dataset/test_num_samples.py new file mode 100644 index 00000000000..69176d8ba20 --- /dev/null +++ b/tests/ut/python/dataset/test_num_samples.py @@ -0,0 +1,56 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds +from mindspore import log as logger + + +def test_num_samples(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + num_samples = 1 + # sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1) + data1 = ds.ManifestDataset( + manifest_file, num_samples=num_samples, num_shards=3, shard_id=1 + ) + row_count = 0 + for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + row_count += 1 + assert row_count == 1 + + +def test_num_samples_tf(): + logger.info("test_tfrecord_read_all_dataset") + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" + files = ["../data/dataset/testTFTestAllTypes/test.data"] + # here num samples indicate the rows per shard. Total rows in file = 12 + ds1 = ds.TFRecordDataset(files, schema_file, num_samples=2) + count = 0 + for _ in ds1.create_tuple_iterator(num_epochs=1): + count += 1 + assert count == 2 + + +def test_num_samples_image_folder(): + data_dir = "../data/dataset/testPK/data" + ds1 = ds.ImageFolderDataset(data_dir, num_samples=2, num_shards=2, shard_id=0) + count = 0 + for _ in ds1.create_tuple_iterator(num_epochs=1): + count += 1 + assert count == 2 + + +if __name__ == "__main__": + test_num_samples() + test_num_samples_tf() + test_num_samples_image_folder()