diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index a62994cb51c..5f61c86f06e 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -54,6 +54,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D
   {kGenerator, &DEPipeline::ParseGeneratorOp},
   {kTfReader, &DEPipeline::ParseTFReaderOp},
   {kProject, &DEPipeline::ParseProjectOp},
+  {kTake, &DEPipeline::ParseTakeOp},
   {kImageFolder, &DEPipeline::ParseImageFolderOp},
   {kMnist, &DEPipeline::ParseMnistOp},
   {kManifest, &DEPipeline::ParseManifestOp},
@@ -650,7 +651,16 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  if (args["count"].is_none()) {
+    std::string err_msg = "Error: count is invalid or not set.";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::shared_ptr<TakeOp> op;
+  RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
 
 Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
   std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>();
diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h
index 35276e5b745..6ff7bb091cd 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.h
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.h
@@ -116,7 +116,7 @@ class DEPipeline {
   Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
-  DsOpPtr ParseTakeOp(const py::dict &args) const;
+  Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h
index b39ba3442be..b865c542604 100644
--- a/mindspore/ccsrc/dataset/core/client.h
+++ b/mindspore/ccsrc/dataset/core/client.h
@@ -38,6 +38,7 @@
 #include "dataset/engine/datasetops/source/mindrecord_op.h"
 #include "dataset/engine/datasetops/source/storage_op.h"
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/take_op.h"
 #include "dataset/engine/datasetops/zip_op.h"
 #include "dataset/engine/execution_tree.h"
 #include "dataset/util/status.h"
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
index 9e511f78f4a..655a739ada7 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
+++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
@@ -5,13 +5,13 @@ add_library(engine-datasetops OBJECT
     parallel_op.cc
     pipeline_op.cc
     batch_op.cc
-    batch_op.cc
     device_queue_op.cc
     map_op.cc
     project_op.cc
    rename_op.cc
     repeat_op.cc
     skip_op.cc
+    take_op.cc
     shuffle_op.cc
     zip_op.cc
     )
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
index 5b0433b6c81..90c160b5bff 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
@@ -88,6 +88,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
   // If the buffer is null or contains no rows,
   // then get a buffer from the child.
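+  // Note, however, that an EOF buffer must be forwarded up the pipeline untouched;
+  // pulling from the child again after an EOF would stall the pipeline.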
   if (!buf || buf->NumRows() == 0) {
+    if (buf && buf->eof()) {
+      *p_buffer = std::move(buf);
+      return Status::OK();
+    }
     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
   }
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
new file mode 100644
index 00000000000..d9625b6c26d
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "common/utils.h"
+#include "dataset/engine/data_buffer.h"
+#include "dataset/engine/datasetops/take_op.h"
+#include "dataset/engine/db_connector.h"
+#include "dataset/engine/execution_tree.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
+
+Status TakeOp::Builder::SanityCheck() const {
+  if (build_max_takes_ <= 0) {
+    std::string err_msg("Take count must be greater than 0.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<TakeOp>(build_max_takes_);
+  return Status::OK();
+}
+
+// Constructor of the TakeOp.
+TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
+
+// A print method, typically used for debugging.
+void TakeOp::Print(std::ostream &out, bool show_all) const {
+  // Call the base class printer first.
+  PipelineOp::Print(out, show_all);
+
+  // Then display our own state.
+  out << "TakeOp:"
+      << "\nCurrent take count: " << take_count_ << "\nMax take count: " << max_takes_;
+}
+
+// This function is called multiple times to return the next buffer; it stops once the
+// required max take count is reached or an EOF buffer is received.
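+// While running, it forwards rows until max_takes_ is reached, then emits an EOE
+// (and, unless this is the last repeat, resets the count and drains the rest of
+// the child's epoch); a later call in the idle state emits an EOF so that
+// downstream operators terminate normally.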
+Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
+  if (child_.empty()) {
+    RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
+  }
+
+  std::unique_ptr<DataBuffer> buf;
+
+  bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
+  if (take_count_ == max_takes_) {
+    if (state_ == OpState::kDeOpRunning) {
+      MS_LOG(INFO) << "Max take count reached, pushing back an EOE buffer.";
+      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+      *p_buffer = std::move(eoe_buffer);
+      state_ = OpState::kDeOpIdle;
+
+      // Unless this is the last repeat, reset the count and drain the rest of the epoch.
+      if (!last_repeat) {
+        take_count_ = 0;
+        RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+        while (!buf->eoe() && !buf->eof()) {
+          RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+        }
+      }
+    } else {
+      MS_LOG(INFO) << "Max take count reached, pushing back an EOF buffer.";
+      auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+      *p_buffer = std::move(eof_buffer);
+      take_count_ = 0;
+    }
+    return Status::OK();
+  }
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+  // An EOE marks the end of the epoch: reset the count and pass the EOE through.
+  if (buf->eoe()) {
+    take_count_ = 0;
+    *p_buffer = std::move(buf);
+    return Status::OK();
+  }
+
+  // An EOF means the child is exhausted: pass the EOF through.
+  if (buf->eof()) {
+    *p_buffer = std::move(buf);
+    return Status::OK();
+  }
+
+  // Otherwise, fill the output buffer while take_count_ is still below the limit.
+  if (take_count_ < max_takes_) {
+    RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
+  }
+  return Status::OK();
+}
+
+// FillBuffer prepares the buffer that is returned to the caller.
+Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
+  int32_t buffer_size = (*buffer)->NumRows();
+  if (take_count_ + buffer_size < max_takes_) {
+    // The whole buffer fits under the limit, so forward it unchanged.
+    *data_buffer = std::move(*buffer);
+    take_count_ = take_count_ + buffer_size;
+  } else {
+    // This is the last buffer: pop rows one at a time until the limit is reached.
+    MS_LOG(INFO) << "Last buffer: pushing back a partial buffer.";
+    std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
+    while (take_count_ < max_takes_) {
+      TensorRow new_row;
+      RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
+      take_count_++;
+      new_tensor_table->push_back(new_row);
+    }
+    (*buffer)->set_tensor_table(std::move(new_tensor_table));
+    *data_buffer = std::move(*buffer);
+  }
+  return Status::OK();
+}
+
+// Class functor operator() override.
+// Most dataset ops operate by launching a thread (see ExecutionTree).
+// However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
+// functor since this op runs inlined inside another operator. The function is overloaded to
+// ensure that it is not called by mistake (it will generate an error).
+Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
+
+Status TakeOp::PrepareNodePostAction() {
+  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
+  tree_->AddToRepeatStack(shared_from_this());
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h
new file mode 100644
index 00000000000..02218cf610e
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class TakeOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the TakeOp is used to help manage all of the arguments
+  // for constructing it. This take op is very simple though, so this builder is really just
+  // provided for a consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @param count - The number of rows to take
+    // @return This is a constructor.
+    explicit Builder(int32_t count);
+
+    // Default destructor
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new TakeOp object
+    Status Build(std::shared_ptr<TakeOp> *);
+
+   private:
+    int32_t build_max_takes_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the TakeOp.
+  // @note The builder class should be used to call this constructor.
+  // @param count - The number of rows to take
+  explicit TakeOp(int32_t count);
+
+  // Destructor
+  ~TakeOp() = default;
+
+  // A print method, typically used for debugging.
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload.
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param ro - reference to the TakeOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) {
+    ro.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator() override.
+  // Most dataset ops operate by launching a thread (see ExecutionTree).
+  // However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
+  // functor since this op runs inlined inside another operator. The function is overloaded to
+  // ensure that it is not called by mistake (it will generate an error).
+  // @return Status - The error code returned
+  Status operator()() override;
+
+  // Gets a buffer from the child node. The caller is typically our parent node.
+  // @note This function sets the retry_if_eoe flag when popping from the child connector. This way,
+  // this function will retry the pop and return the next non-EOE buffer, if any.
+  // @param p_buffer - output pointer to the buffer that it will fetch.
+  // @param worker_id - The worker id
+  // @param retry_if_eoe - Set this flag to true to allow calling pop() again after the first pop() returns EOE.
+  // @return Status - The error code returned
+  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
+
+  // During the tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementation.
+  Status PrepareNodePostAction() override;
+
+ private:
+  int32_t max_takes_;   // The number of takes that the user requested
+  int32_t take_count_;  // A counter for the number of rows taken so far
+
+  Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer);
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 642e2beec8c..8de56a6dff2 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -36,7 +36,7 @@ from mindspore import log as logger
 from . import samplers
 from .iterators import DictIterator, TupleIterator
 from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
-    check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
+    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
     check_zip_dataset, check_add_column
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -442,6 +442,33 @@ class Dataset:
         """
         return SkipDataset(self, count)
 
+    @check_take
+    def take(self, count=-1):
+        """
+        Takes at most the given number of elements from the dataset.
+
+        Note:
+            1. If count is greater than the number of elements in the dataset, or equal to -1,
+               all the elements in the dataset will be taken.
+            2. The order of take and batch matters. If take comes before batch, the given
+               number of rows is taken; otherwise, the given number of batches is taken.
+
+        Args:
+            count (int, optional): Number of elements to be taken from the dataset (default=-1).
+
+        Returns:
+            TakeDataset, dataset taken.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # data is an instance of a Dataset object.
+            >>> # Create a dataset containing at most 50 elements.
+            >>> data = data.take(50)
+        """
+        if count == -1:
+            return self
+        return TakeDataset(self, count)
+
     @check_zip_dataset
     def zip(self, datasets):
         """
@@ -1100,6 +1127,7 @@ class RepeatDataset(DatasetOp):
         """
         return self.count
 
+
 class SkipDataset(DatasetOp):
     """
     The result of applying Skip operator to the input Dataset.
@@ -1134,6 +1162,41 @@ class SkipDataset(DatasetOp):
         output_size = child_size - self.count
         return output_size
 
+
+class TakeDataset(DatasetOp):
+    """
+    The result of applying Take operator to the input Dataset.
+
+    Args:
+        input_dataset (Dataset): Input Dataset to have elements taken from.
+        count (int): Number of elements to be taken from the dataset.
+    """
+
+    def __init__(self, input_dataset, count):
+        super().__init__()
+        self.count = count
+        self.input.append(input_dataset)
+        input_dataset.output.append(self)
+        self._input_indexs = input_dataset.input_indexs
+
+    def get_args(self):
+        args = super().get_args()
+        args["count"] = self.count
+        return args
+
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        Return:
+            Number, number of batches.
+ """ + child_size = self.input[0].get_dataset_size() + if child_size < self.count: + return child_size + return self.count + + class ZipDataset(DatasetOp): """ The result of applying Zip operator to the input Dataset. diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index d670de508c4..3d6873d04c4 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -129,6 +129,8 @@ class Iterator: op_type = OpName.REPEAT elif isinstance(dataset, de.SkipDataset): op_type = OpName.SKIP + elif isinstance(dataset, de.TakeDataset): + op_type = OpName.TAKE elif isinstance(dataset, de.StorageDataset): op_type = OpName.STORAGE elif isinstance(dataset, de.ImageFolderDatasetV2): diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index a54a7a6b321..61417e4d52e 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -304,6 +304,9 @@ def create_node(node): elif dataset_op == 'SkipDataset': pyobj = de.Dataset().skip(node.get('count')) + elif dataset_op == 'TakeDataset': + pyobj = de.Dataset().take(node.get('count')) + elif dataset_op == 'MapDataset': tensor_ops = construct_tensor_ops(node.get('operations')) pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'), diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 3502cbb2045..b74e913202f 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -602,7 +602,7 @@ def check_batch_size(batch_size): def check_count(count): check_type(count, 'count', int) if (count <= 0 and count != -1) or count > INT32_MAX: - raise ValueError("repeat count should be either -1 or positive integer.") + raise ValueError("count should be either -1 or positive integer.") def check_columns(columns, name): @@ -709,6 +709,7 @@ def check_repeat(method): return new_method + def check_skip(method): """check the input arguments of skip.""" @wraps(method) @@ -724,6 +725,21 @@ def check_skip(method): return new_method + +def check_take(method): + """check the input arguments of take.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + count = param_dict.get('count') + check_count(count) + + return method(*args, **kwargs) + + return new_method + + def check_zip(method): """check the input arguments of zip.""" @wraps(method) @@ -759,6 +775,7 @@ def check_zip_dataset(method): return new_method + def check_rename(method): """check the input arguments of rename.""" @wraps(method) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index db207363a8e..ae9c46e62c9 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -64,6 +64,7 @@ SET(DE_UT_SRCS voc_op_test.cc cifar_op_test.cc celeba_op_test.cc + take_op_test.cc ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/take_op_test.cc b/tests/ut/cpp/dataset/take_op_test.cc new file mode 100644 index 00000000000..7f8508de20c --- /dev/null +++ b/tests/ut/cpp/dataset/take_op_test.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "common/common.h"
+#include "common/utils.h"
+#include "dataset/core/client.h"
+#include "gtest/gtest.h"
+#include "utils/log_adapter.h"
+
+namespace common = mindspore::common;
+
+using namespace mindspore::dataset;
+using mindspore::MsLogLevel::INFO;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::LogStream;
+
+class MindDataTestTakeOp : public UT::DatasetOpTesting {};
+
+TEST_F(MindDataTestTakeOp, TestTakeProject) {
+  // Start with an empty execution tree.
+  auto my_tree = std::make_shared<ExecutionTree>();
+
+  std::string dataset_path;
+  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
+
+  // TFReaderOp
+  std::shared_ptr<TFReaderOp> my_tfreader_op;
+  TFReaderOp::Builder builder;
+  builder.SetDatasetFilesList({dataset_path})
+    .SetRowsPerBuffer(16)
+    .SetWorkerConnectorSize(16)
+    .SetNumWorkers(16);
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
+  builder.SetDataSchema(std::move(schema));
+  Status rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  // TakeOp
+  std::shared_ptr<TakeOp> my_take_op;
+  TakeOp::Builder builder_take(5);
+  rc = builder_take.Build(&my_take_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = my_tree->AssociateNode(my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+  rc = my_tree->AssociateNode(my_take_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  // Set children/root layout.
+  rc = my_take_op->AddChild(my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+  rc = my_tree->AssignRoot(my_take_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  MS_LOG(INFO) << "Launching tree and beginning iteration.";
+  rc = my_tree->Prepare();
+
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = my_tree->Launch();
+  ASSERT_TRUE(rc.IsOk());
+
+  // Start the loop of reading tensors from our pipeline.
+  DatasetIterator di(my_tree);
+  TensorRow tensor_list;
+  rc = di.FetchNextTensorRow(&tensor_list);
+  ASSERT_TRUE(rc.IsOk());
+
+  int row_count = 0;
+  while (!tensor_list.empty()) {
+    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
+
+    // Display the tensor by calling the printer on it.
+    for (int i = 0; i < tensor_list.size(); i++) {
+      std::ostringstream ss;
+      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
+      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
+    }
+
+    rc = di.FetchNextTensorRow(&tensor_list);
+    ASSERT_TRUE(rc.IsOk());
+    row_count++;
+  }
+
+  // Only the first 5 rows should come out of the pipeline.
+  ASSERT_EQ(row_count, 5);
+}
diff --git a/tests/ut/python/dataset/test_take.py b/tests/ut/python/dataset/test_take.py
new file mode 100644
index 00000000000..ed71f67e26c
--- /dev/null
+++ b/tests/ut/python/dataset/test_take.py
@@ -0,0 +1,317 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as vision
+from mindspore import log as logger
+import numpy as np
+
+
+# In this generator dataset the number of rows is 3, with values 0, 1, 2
+def generator():
+    for i in range(3):
+        yield np.array([i]),
+
+
+# In this generator dataset the number of rows is 10, with values 0, 1, 2 ... 9
+def generator_10():
+    for i in range(10):
+        yield np.array([i]),
+
+
+def test_take_01():
+    """
+    Test take: the input has 3 rows and we take 1 row; in this case we will meet neither EOE nor EOF
+    """
+    logger.info("test_take_01")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(1)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 0 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_02():
+    """
+    Test take: the input has 3 rows and we take 2 rows; in this case we will meet EOE
+    """
+    logger.info("test_take_02")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(2)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 2 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_03():
+    """
+    Test take: the input has 3 rows and we take 3 rows; in this case we will meet both EOE and EOF
+    """
+    logger.info("test_take_03")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(3)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_04():
+    """
+    Test take: the input has 3 rows and we take 4 rows, which is more than the total number of rows
+    """
+    logger.info("test_take_04")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(4)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_05():
+    """
+    Test take: there is no repeat op
+    """
+    logger.info("test_take_05")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_06():
+    """
+    Test take: repeat comes before take
+    """
+    logger.info("test_take_06")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.repeat(2)
+    data1 = data1.take(4)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_07():
+    """
+    Test take: take comes before batch, so N in take(N) refers to the number of rows
+    """
+    logger.info("test_take_07")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(2)
+    data1 = data1.batch(2)
+    assert sum([1 for _ in data1]) == 1
+
+
+def test_take_08():
+    """
+    Test take: take comes after batch, so N in take(N) refers to the number of batches
+    """
+    logger.info("test_take_08")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.batch(2)
+    data1 = data1.take(2)
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_09():
+    """
+    Test take: take count is -1, so the whole dataset is read; take comes after repeat
+    """
+    logger.info("test_take_09")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.repeat(2)
+    data1 = data1.take(-1)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_10():
+    """
+    Test take: take count is -1, so the whole dataset is read; take comes before repeat
+    """
+    logger.info("test_take_10")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(-1)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_11():
+    """
+    Test take: batch first, then do the repeat and take operations
+    """
+    logger.info("test_take_11")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.batch(2)
+    data1 = data1.repeat(2)
+    data1 = data1.take(-1)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 2 * (i % 2) == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_12():
+    """
+    Test take: take first, then do the batch and repeat operations
+    """
+    logger.info("test_take_12")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(2)
+    data1 = data1.batch(2)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 0 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_13():
+    """
+    Test take: skip first, then do the take, batch and repeat operations
+    """
+    logger.info("test_take_13")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.skip(2)
+    data1 = data1.take(-1)
+    data1 = data1.batch(2)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 2 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_14():
+    """
+    Test take: take first, then do the batch, skip and repeat operations
+    """
+    logger.info("test_take_14")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+
+    data1 = data1.take(-1)
+    data1 = data1.batch(2)
+    data1 = data1.skip(1)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 2 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_15():
+    """
+    Test take: on a larger dataset, take a part, then do the skip operation
+    """
+    logger.info("test_take_15")
+    data1 = ds.GeneratorDataset(generator_10, ["data"])
+
+    data1 = data1.take(6)
+    data1 = data1.skip(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert (i + 2) == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_16():
+    """
+    Test take: on a larger dataset, skip a part, then do the take operation
+    """
+    logger.info("test_take_16")
+    data1 = ds.GeneratorDataset(generator_10, ["data"])
+
+    data1 = data1.skip(3)
+    data1 = data1.take(5)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert (i + 3) == d[0][0]
+
+    assert sum([1 for _ in data1]) == 5
+
+
+if __name__ == '__main__':
+    test_take_01()
+    test_take_02()
+    test_take_03()
+    test_take_04()
+    test_take_05()
+    test_take_06()
+    test_take_07()
+    test_take_08()
+    test_take_09()
+    test_take_10()
+    test_take_11()
+    test_take_12()
+    test_take_13()
+    test_take_14()
+    test_take_15()
+    test_take_16()
+    logger.info('== test take operation finished ==')
\ No newline at end of file
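As a usage sketch for reviewers (the generator and the numbers here are illustrative, not part of this change), the take-before-batch rule documented in take() plays out as follows:

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        for i in range(10):  # 10 rows: 0..9
            yield (np.array([i]),)

    # take(4) before batch(2) keeps 4 rows, which then form 2 batches:
    # [[0], [1]] and [[2], [3]]; with batch(2).take(4) it would keep 4 batches.
    data = ds.GeneratorDataset(gen, ["data"]).take(4).batch(2)
    assert sum(1 for _ in data) == 2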