diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 27e5f8d2b1c..26f58fbde73 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -53,6 +53,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kRepeat, &DEPipeline::ParseRepeatOp}, {kSkip, &DEPipeline::ParseSkipOp}, {kZip, &DEPipeline::ParseZipOp}, + {kConcat, &DEPipeline::ParseConcatOp}, {kRename, &DEPipeline::ParseRenameOp}, {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, {kGenerator, &DEPipeline::ParseGeneratorOp}, @@ -757,6 +758,14 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr * return Status::OK(); } +Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr) { // Required arguments std::shared_ptr builder = std::make_shared(); diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 699348f1578..4ecfb080c1b 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -46,6 +46,7 @@ enum OpName { kSkip, kTake, kZip, + kConcat, kMap, kFilter, kDeviceQueue, @@ -127,6 +128,8 @@ class DEPipeline { Status ParseZipOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseConcatOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr); Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index efb04a3bd04..8a5d7651fb5 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -468,6 +468,7 @@ 
PYBIND11_MODULE(_c_dataengine, m) { .value("SKIP", OpName::kSkip) .value("TAKE", OpName::kTake) .value("ZIP", OpName::kZip) + .value("CONCAT", OpName::kConcat) .value("MAP", OpName::kMap) .value("FILTER", OpName::kFilter) .value("DEVICEQUEUE", OpName::kDeviceQueue) diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index 40de887aea9..aa5e85f7de9 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -42,6 +42,7 @@ #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/datasetops/take_op.h" #include "dataset/engine/datasetops/zip_op.h" +#include "dataset/engine/datasetops/concat_op.h" #include "dataset/engine/execution_tree.h" #include "dataset/util/status.h" diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index dedba4a4e7b..70065df5f48 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(engine-datasetops OBJECT take_op.cc shuffle_op.cc zip_op.cc + concat_op.cc filter_op.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc new file mode 100644 index 00000000000..eb6401409ac --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "common/utils.h" +#include "dataset/core/config_manager.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/datasetops/concat_op.h" +#include "dataset/engine/db_connector.h" +#include "dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +ConcatOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +// The builder "build" method creates the final object. +Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { + *ptr = std::make_shared(builder_op_connector_size_); + return Status::OK(); +} + +// Constructor of the ConcatOp. +ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {} + +// A function that prints info about the Operator +void ConcatOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this is summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nDatasets: " << children_num_ << "\n\n"; + } +} + +// Main entry point for Concat +Status ConcatOp::operator()() { + // The children_num_ parameter needs to be put here + children_num_ = static_cast(child_.size()); + + TaskManager::FindMe()->Post(); + std::unique_ptr buf; + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); + + // Obtain columns_name_id_map from child_[0] + column_name_id_map_ 
= child_[0]->column_name_id_map(); + if (column_name_id_map_.empty()) { + RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); + } + + int eof_count = 0; + while (eof_count != children_num_) { + for (int i = 0; i < children_num_; i++) { + // 1. Throw the eof buffer when meet it + if (buf->eof() || buf->eoe()) { + RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + } + // 2. Do varification as for column name, column data type and rank of column data + RETURN_IF_NOT_OK(Verify(i, buf)); + + // 3. Put the data into output_connector + while (!buf->eoe() && !buf->eof()) { + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); + RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + } + + // 4. Throw the eoe buffer when meet it + if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) { + RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + } + // 5. Add eoe buffer after get buffer from all child + if (i == (children_num_ - 1)) { + auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + } + if (buf->eof()) { + eof_count++; + } + } + } + // 6. 
Add eof buffer in the end manually + MS_LOG(DEBUG) << "Add the eof buffer manualy in the end."; + auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + return Status::OK(); +} + +Status ConcatOp::Verify(int32_t id, const std::unique_ptr &buf) { + TensorRow new_row; + buf->GetRow(0, &new_row); + + if (id == 0) { + // Obtain the column name, data type and data rank in child[0] + column_name_id_ = child_[id]->column_name_id_map(); + for (auto item : new_row) { + data_type_.push_back(item->type()); + data_rank_.push_back(item->Rank()); + } + } else { + // Compare the column name, data type and data rank with these in child[0] + if (child_[id]->column_name_id_map() != column_name_id_) { + RETURN_STATUS_UNEXPECTED("The column name or column order is not the same with previous dataset."); + } + int32_t index = 0; + for (auto item : new_row) { + if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) { + RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same with previous dataset."); + } + } + } + return Status::OK(); +} + +Status ConcatOp::PrepareNodePostAction() { + RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); + tree_->AddToRepeatStack(shared_from_this()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h new file mode 100644 index 00000000000..9afadab39a2 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ + +#include +#include +#include +#include +#include "dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class ConcatOp : public PipelineOp { + public: + // The nested builder class inside of the ConcatOp is used to help manage all of the arguments + // for constructing it. This Concat op is very simple though, so this builder is really just + // provided for a consistent look and feel for creators of Dataset operators overall. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new StorageOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_op_connector_size_; + }; + + // Constructor of the ConcatOp. 
+ // @note The builder class should be used to call it + // @param op_connector_size - connector size + explicit ConcatOp(int32_t op_connector_size); + + // Destructor + ~ConcatOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param ro - reference to the ConcatOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) { + ro.Print(out, false); + return out; + } + + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. + Status PrepareNodePostAction() override; + + private: + Status Verify(int32_t id, const std::unique_ptr &buf); + + int32_t children_num_; // The num of child of parent node. 
+ std::unordered_map column_name_id_; // Mapping between col index and col name + std::vector data_type_; + std::vector data_rank_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 39aafeda095..7fb0214e381 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset + check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -147,6 +147,9 @@ class Dataset: self._repeat_count = None self._sync = False + def __add__(self, datasets): + return self.concat(datasets) + def get_args(self): """ Returns attributes (member variables) related to the current class. @@ -560,6 +563,37 @@ class Dataset: raise TypeError("The zip function %s type error!" % (datasets)) return ZipDataset(datasets) + @check_concat + def concat(self, datasets): + """ + Concat the datasets in the input list of datasets, supported using "+" to reload concat operation. + + Note: + The column nameļ¼Œcolumn data type and rank of column data should be the same in input datasets. + + Args: + datasets (list or class Dataset): A list of datasets or a single class Dataset + to be concated together with this dataset. + + Returns: + ConcatDataset, dataset concated. 
class ConcatDataset(DatasetOp):
    """
    The result of applying concat dataset operator to the input Dataset.

    Args:
        datasets (list): A list of datasets to be concated together.

    Raises:
        TypeError: If dataset is not an instance of Dataset.
    """

    def __init__(self, datasets):
        super().__init__()
        # Validate every element before any wiring happens, so a bad element
        # leaves the other datasets' `output` lists untouched.
        for dataset in datasets:
            if not isinstance(dataset, Dataset):
                # One-element tuple: a tuple-valued `dataset` is then formatted
                # as a whole instead of being unpacked by the `%` operator.
                raise TypeError("The parameter %s of concat has type error!" % (dataset,))
        self.datasets = datasets
        # Link this node into the pipeline graph in both directions.
        for data in datasets:
            self.input.append(data)
            data.output.append(self)

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, number of batches (the sum of the sizes reported by the inputs).
        """
        # The built-in sum keeps integer sizes as a plain Python int; np.sum
        # would hand back a NumPy scalar (and a float 0.0 for an empty list).
        return sum(child.get_dataset_size() for child in self.input)
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index e32c188d005..b778bdacae7 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -156,6 +156,8 @@ class Iterator: op_type = OpName.BARRIER elif isinstance(dataset, de.ZipDataset): op_type = OpName.ZIP + elif isinstance(dataset, de.ConcatDataset): + op_type = OpName.CONCAT elif isinstance(dataset, de.MapDataset): op_type = OpName.MAP elif isinstance(dataset, de.FilterDataset): diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index f588d572bba..408a24e16c7 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -335,6 +335,10 @@ def create_node(node): # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller. pyobj = de.ZipDataset((de.Dataset(), de.Dataset())) + elif dataset_op == 'ConcatDataset': + # Create ConcatDataset instance, giving dummy input dataset that will be overrided in the caller. 
def check_concat(method):
    """Validate the `datasets` argument of the concat method in `Dataset` before calling it."""

    @wraps(method)
    def new_method(*args, **kwargs):
        to_concat = make_param_dict(method, args, kwargs).get("datasets")

        # `datasets` is a required argument
        if to_concat is None:
            raise ValueError("datasets is not provided.")

        # either a list or a single Dataset is accepted
        accepted_types = (list, datasets.Dataset)
        if isinstance(to_concat, accepted_types):
            return method(*args, **kwargs)

        raise ValueError("datasets is not list or of type Dataset.")

    return new_method
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "common/common.h" +#include "common/utils.h" +#include "dataset/core/client.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" + +namespace common = mindspore::common; + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestConcatOp : public UT::DatasetOpTesting {}; + + +TEST_F(MindDataTestConcatOp, TestConcatProject) { +/* Tree: + * + * OpId(2) ConcatOp + * / \ + * OpId(0) TFReaderOp OpId(1) TFReaderOp + * + * Start with an empty execution tree +*/ + MS_LOG(INFO) << "UT test TestConcatProject."; + auto my_tree = std::make_shared(); + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + + // TFReaderOp1 + std::shared_ptr my_tfreader_op1; + TFReaderOp::Builder builder1; + builder1.SetDatasetFilesList({dataset_path}) + .SetRowsPerBuffer(16) + .SetWorkerConnectorSize(16) + .SetNumWorkers(16); + std::unique_ptr schema1 = std::make_unique(); + schema1->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {}); + builder1.SetDataSchema(std::move(schema1)); + Status rc = builder1.Build(&my_tfreader_op1); + ASSERT_TRUE(rc.IsOk()); + rc = my_tree->AssociateNode(my_tfreader_op1); + ASSERT_TRUE(rc.IsOk()); + + // TFReaderOp2 + std::shared_ptr my_tfreader_op2; + TFReaderOp::Builder builder2; + builder2.SetDatasetFilesList({dataset_path}) + .SetRowsPerBuffer(16) + .SetWorkerConnectorSize(16) + .SetNumWorkers(16); + 
std::unique_ptr schema2 = std::make_unique(); + schema2->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {}); + builder2.SetDataSchema(std::move(schema2)); + rc = builder2.Build(&my_tfreader_op2); + ASSERT_TRUE(rc.IsOk()); + rc = my_tree->AssociateNode(my_tfreader_op2); + ASSERT_TRUE(rc.IsOk()); + + // Creating ConcatOp + std::shared_ptr concat_op; + rc = ConcatOp::Builder().Build(&concat_op); + EXPECT_TRUE(rc.IsOk()); + + rc = my_tree->AssociateNode(concat_op); + EXPECT_TRUE(rc.IsOk()); + rc = concat_op->AddChild(std::move(my_tfreader_op1)); + EXPECT_TRUE(rc.IsOk()); + rc = concat_op->AddChild(std::move(my_tfreader_op2)); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree->AssignRoot(concat_op); + EXPECT_TRUE(rc.IsOk()); + rc = my_tree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + // Launch the tree execution to kick off threads and start running the pipeline + MS_LOG(INFO) << "Launching my tree."; + rc = my_tree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Simulate a parse of data from our pipeline. 
+ std::shared_ptr rootNode = my_tree->root(); + + DatasetIterator di(my_tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + EXPECT_TRUE(rc.IsOk()); + + int row_count = 0; + while (!tensor_list.empty()) { + MS_LOG(INFO) << "Row display for row #: " << row_count << "."; + + // Display the tensor by calling the printer on it + for (int i = 0; i < tensor_list.size(); i++) { + std::ostringstream ss; + ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; + MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << "."; + } + rc = di.FetchNextTensorRow(&tensor_list); + EXPECT_TRUE(rc.IsOk()); + row_count++; + } + ASSERT_EQ(row_count, 24); // Should be 24 rows fetched +} \ No newline at end of file diff --git a/tests/ut/python/dataset/test_concat.py b/tests/ut/python/dataset/test_concat.py new file mode 100644 index 00000000000..fad1288a041 --- /dev/null +++ b/tests/ut/python/dataset/test_concat.py @@ -0,0 +1,377 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def generator():
    """Yield 3 one-column rows whose values are 0, 1, 2."""
    for value in range(3):
        row = (np.array([value]),)
        yield row


def generator_10():
    """Yield 7 one-column rows whose values are 3, 4, ..., 9 (10 is excluded)."""
    first, stop = 3, 10
    value = first
    while value < stop:
        yield (np.array([value]),)
        value += 1


def generator_20():
    """Yield 10 one-column rows whose values are 10, 11, ..., 19 (20 is excluded)."""
    yield from ((np.array([value]),) for value in range(10, 20))
def test_concat_04():
    """
    Test concat: test concat dataset that has different rank
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    # Same column name as data1: with a different name the concat would already
    # be rejected for its column names and the rank check would never be reached.
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    # batch() is what this test relies on to make data2's rows differ in rank.
    data2 = data2.batch(3)

    data3 = data1 + data2

    # Iterating is expected to fail with RuntimeError.
    raised = False
    try:
        for _ in data3:
            pass
    except RuntimeError:
        raised = True
    assert raised
def test_concat_09():
    """
    Test concat: both inputs are repeated twice before they are concatenated
    """
    logger.info("test_concat_09")
    repeated_first = ds.GeneratorDataset(generator, ["col1"]).repeat(2)
    repeated_second = ds.GeneratorDataset(generator_10, ["col1"]).repeat(2)
    merged = repeated_first + repeated_second

    # Each input is expected to run through both of its repeats before the next one starts.
    expected = [0, 1, 2] * 2 + [3, 4, 5, 6, 7, 8, 9] * 2

    for position, row in enumerate(merged):
        logger.info("data: %i", row[0][0])
        assert expected[position] == row[0][0]

    row_total = sum(1 for _ in merged)
    assert row_total == 20
def test_concat_14():
    """
    Test concat: create datasets from different dataset folders, apply the same operations, then concat
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    # The content comparison below reads each input once on its own and once through
    # the concat, so the read order has to be the same both times.
    # NOTE(review): shuffle=False is passed on the assumption that the default sampler
    # may shuffle - confirm against the ImageFolderDatasetV2 documentation.
    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=3, shuffle=False)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR2, num_samples=2, shuffle=False)

    transforms1 = F.ComposeOp([F.Decode(),
                               F.Resize((224, 224)),
                               F.ToTensor()])

    data1 = data1.map(input_columns=["image"], operations=transforms1())
    data2 = data2.map(input_columns=["image"], operations=transforms1())
    data3 = data1 + data2

    expected, output = [], []
    for d in data1:
        expected.append(d[0])
    for d in data2:
        expected.append(d[0])
    for d in data3:
        output.append(d[0])

    assert len(expected) == len(output)
    # The result of the comparison has to be asserted; a bare call checks nothing.
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5