diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index f6440710b1b..a02d995147a 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -29,6 +29,7 @@ #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" #include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" @@ -45,6 +46,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kShuffle, &DEPipeline::ParseShuffleOp}, {kMindrecord, &DEPipeline::ParseMindRecordOp}, {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, {kBatch, &DEPipeline::ParseBatchOp}, {kRepeat, &DEPipeline::ParseRepeatOp}, {kSkip, &DEPipeline::ParseSkipOp}, @@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * return Status::OK(); } +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + + if (args["predicate"].is_none()) { + RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "predicate") { + py::handle op = args["predicate"]; + if (!py::isinstance(op)) { + RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); + } + py::function predicate_func = op.cast(); + (void)builder->SetPredicateFunc(std::move(predicate_func)); + } else if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)builder->SetInColNames(in_col_names); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr) { if (args["count"].is_none()) { std::string err_msg = "Error: count is invalid or not set."; @@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr * return Status::OK(); } -DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); } - Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr) { // Required arguments std::shared_ptr builder = std::make_shared(); diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index eadde2c1910..25919afe588 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -107,6 +107,8 @@ class DEPipeline { Status ParseMapOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseFilterOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr); Status ParseSkipOp(const py::dict &args, std::shared_ptr *ptr); @@ -121,8 +123,6 @@ class DEPipeline { Status ParseZipOp(const py::dict &args, std::shared_ptr *ptr); - DsOpPtr ParseFilterOp(const py::dict &args) const; - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr); Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr); diff --git 
a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index b865c542604..15064dee6b8 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -31,6 +31,7 @@ #include "dataset/engine/datasetops/map_op.h" #include "dataset/engine/datasetops/project_op.h" #include "dataset/engine/datasetops/rename_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/shuffle_op.h" diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index a566d51f5c1..3f41f27726d 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) c DS_ASSERT(data_); switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t); + CASE_PRINT_HEX(DataType::DE_BOOL, bool); CASE_PRINT_HEX(DataType::DE_INT8, int8_t); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 655a739ada7..7de62d9d110 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT take_op.cc shuffle_op.cc zip_op.cc + filter_op.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc new file mode 100644 index 00000000000..22b1155fc9b --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -0,0 +1,273 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/filter_op.h" +#include +#include +#include +#include +#include +#include "dataset/core/config_manager.h" +#include "dataset/core/constants.h" +#include "dataset/core/global_context.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/engine/execution_tree.h" +#include "dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { + +Status FilterOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; + return err.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +FilterOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status FilterOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, + builder_predicate_func_); + return Status::OK(); +} + +FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func) + : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} + +Status FilterOp::operator()() { + // The operator class just starts off threads by calling the tree_ function. + RETURN_UNEXPECTED_IF_NULL(tree_); + // Synchronize with TaskManager. + TaskManager::FindMe()->Post(); + filter_queues_.Init(num_workers_, oc_queue_size_); + RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK(Collector()); + return Status::OK(); +} + +Status FilterOp::EofReceived(int32_t) { return Status::OK(); } + +Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } + +// Validating if each of the input_columns exists in the DataBuffer. +Status FilterOp::ValidateInColumns(const std::unordered_map &col_name_id_map, + std::vector *input_columns) { + for (const auto &inCol : *input_columns) { + bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +// A print method typically used for debugging. +void FilterOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first. + ParallelOp::Print(out, show_all); + + // Then display our own stuff. + out << "\nFilterOp:"; + out << "\n Input column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } +} + +Status FilterOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + bool worker_stop = false; + while (worker_stop == false) { + // Getting a databuffer to work on. + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); + if (in_buffer->eoe()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + continue; + } else if (in_buffer->eof()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); + worker_stop = true; + continue; + } + + // Thread local variables to avoid lock. When in_columns_ is empty and workers will write + // the name of the first column into input_columns (thread local) instead of in_columns_ (thread global). + std::vector input_columns = in_columns_; + // Indices of the columns to process. + std::vector to_process_indices; + + RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns)); + + // if the databuffer was all filtered, it is marked as kFilterEmpty. + // if the databuffer was partially filtered, it is marked as kFilterPartial. 
+ // if the databuffer was not filtered, it is marked as kFilterFull. + int32_t num_rows = in_buffer->NumRows(); + std::unique_ptr new_tensor_table; + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table)); + + if (new_tensor_table->empty()) { + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty))); + } else if (new_tensor_table->size() == num_rows) { + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull))); + } else { // kFilterPartial + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial))); + } + } + return Status::OK(); +} + +Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector &to_proess_indices, + std::unique_ptr *out) { + *out = std::make_unique(); + int32_t num_rows = in_buffer->NumRows(); + for (int32_t i = 0; i < num_rows; i++) { + TensorRow to_process; + TensorRow cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + + (void)std::transform(to_proess_indices.begin(), to_proess_indices.end(), std::back_inserter(to_process), + [&cur_row](const size_t &it) -> std::shared_ptr { return cur_row[it]; }); + bool predicate = true; + RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); + if (predicate) { + (*out)->push_back(std::move(cur_row)); + } + } + return Status::OK(); +} + +// if the filtered DataBuffer is written directly to out_connector_, +// the thread fetching data will block in a queue. +// Collector function will reorder the DataBuffer in order. +// for example in two work queues: +// int filter_queues_: +// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) +// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) +// after reorder in out_connector_: +// queue1: DB(data2) DB(data4) DB(eof) +// queue2: DB(eoe) DB(eoe) +Status FilterOp::Collector() { + bool collector_stop = false; + uint64_t task_id_cnt = 0; + uint64_t out_id_cnt = 0; + std::pair, filterCtrl> in_pair; + while (collector_stop == false) { + uint32_t w_id = task_id_cnt % num_workers_; + RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); + if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || + in_pair.second == filterCtrl::kFilterEoe) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + out_id_cnt++; + task_id_cnt++; + } else if (in_pair.second == filterCtrl::kFilterEof) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + collector_stop = true; + } else { // kFilterEmpty + task_id_cnt++; + } + } + return Status::OK(); +} + +// initialize some internal data structure used by WorkerEntry(). +Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector *to_process_indices, + std::vector *input_columns) { + int32_t num_rows = in_buf->NumRows(); + int32_t num_cols = in_buf->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); + } + std::unordered_map col_name_id_map = in_buf->column_name_map(); + // Check if there is invalid column name in the inColumns. 
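The re-ordering that Collector() performs (see the queue1/queue2 comment above) is the subtle part of this op: a buffer marked kFilterEmpty is dropped and does not consume an output slot, so the out_connector_ queues stay densely packed while eoe/eof ordering is preserved. A small self-contained Python model of that loop, using hypothetical names rather than the real C++ structures:

from collections import deque

EMPTY, PARTIAL, FULL, EOE, EOF = range(5)   # mirrors the filterCtrl enum

def collect(filter_queues, num_workers):
    # Poll the worker queues round-robin and re-deal surviving buffers
    # round-robin to the output queues, dropping fully filtered buffers.
    # (The real Collector blocks on PopFront; this toy version assumes the
    # eof buffer arrives before any queue runs dry.)
    out_queues = [deque() for _ in range(num_workers)]
    task_id_cnt = 0   # which worker queue to poll next
    out_id_cnt = 0    # which output queue receives the next surviving buffer
    while True:
        buf, ctrl = filter_queues[task_id_cnt % num_workers].popleft()
        if ctrl in (FULL, PARTIAL, EOE):
            out_queues[out_id_cnt % num_workers].append(buf)
            out_id_cnt += 1
            task_id_cnt += 1
        elif ctrl == EOF:
            out_queues[out_id_cnt % num_workers].append(buf)
            return out_queues
        else:   # EMPTY: drop the buffer, keep polling, no output slot consumed
            task_id_cnt += 1

q1 = deque([("data1", EMPTY), ("eoe", EOE), ("data4", FULL), ("eof", EOF)])
q2 = deque([("data2", FULL), ("data3", EMPTY), ("eoe", EOE)])
print(collect([q1, q2], 2))   # [deque(['data2', 'data4', 'eof']), deque(['eoe', 'eoe'])]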
+ RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns)); + + if (input_columns->empty()) { + MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; + // sort the input colunms by column index. + std::vector> sort_vec(col_name_id_map.begin(), col_name_id_map.end()); + std::sort(sort_vec.begin(), sort_vec.end(), + [](const std::pair &a, const std::pair &b) { + return a.second < b.second; + }); + + (void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns), + [](const auto &it) -> std::string { return it.first; }); + } + + // initialize to_process_indices. + (void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices), + [&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; }); + + return Status::OK(); +} + +Status FilterOp::CheckInput(const TensorRow &input) const { + for (auto &item : input) { + if (item == nullptr) { + RETURN_STATUS_UNEXPECTED("input is null."); + } + } + return Status::OK(); +} + +Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { + RETURN_IF_NOT_OK(CheckInput(input)); + // Acquire Python GIL. + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Transform input tensor vector into numpy array vector. + py::tuple input_args(input.size()); + for (size_t i = 0; i < input.size(); i++) { + py::array new_data; + RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); + input_args[i] = new_data; + } + // Invoke python function. + py::object ret_py_obj = predicate_func_(*input_args); + *out_predicate = ret_py_obj.cast(); + } catch (const py::error_already_set &e) { + std::stringstream ss; + ss << e.what() << std::endl; + ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; + return Status(StatusCode::kPyFuncException, ss.str()); + } + return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h new file mode 100644 index 00000000000..50697d398f1 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
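For reviewers of InvokePredicateFunc(): each selected input column reaches the user predicate as a numpy array (in input_columns order), and the return value only needs to be convertible to bool; FilterDataset additionally wraps the callable as lambda *args: bool(predicate(*args)) on the Python side. A minimal illustration of that per-row contract (plain numpy, no MindSpore required):

import numpy as np

def predicate(image, label):
    # label arrives as a numpy array such as np.array([3]); returning a numpy
    # bool is fine because the result is converted with bool() / a bool cast.
    return label == 3

row = (np.zeros((64, 64, 3), dtype=np.uint8), np.array([3]))   # one pipeline row
keep = bool(predicate(*row))   # True -> row is kept, False -> row is filtered out
assert keep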
+ */ +#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/util/queue.h" + +namespace mindspore { +namespace dataset { + +class FilterOp : public ParallelOp { + public: + // The nested builder class inside of the FilterOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args. + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPredicateFunc(py::function func) { + builder_predicate_func_ = std::move(func); + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + build_in_col_names_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + builder_op_connector_size_ = connector_size; + return *this; + } + + // The builder "build" method creates the final object. + // @param ptr The shared_ptr to the new FilterOp object. + // @return Status. + Status Build(std::shared_ptr *ptr); + + private: + // Sanity check for builder class args. + // @return Status - The error code return. + Status SanityCheck(); + std::vector build_in_col_names_; + py::function builder_predicate_func_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + }; + + enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 }; + + // Constructor of FilterOp + // @note The builder class should be used to call it. + // @param in_col_names A list of input column names,when it is empty the predicate will be + // applied all columns in the dataset. + // @param num_workers The number of worker threads. + // @param op_connector_size The size of each queue in the connector. + // @param predicate_func python callable which returns a boolean value. + FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func); + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status The error code return + Status operator()() override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EofReceived(int32_t) override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging. + // @param out The output stream to write output to. + // @param show_all A bool to control if you want to show all info or just a summary. 
+ void Print(std::ostream &out, bool show_all) const override; + + private: + // predicate_func python callable which returns a boolean value. + py::function predicate_func_; + + // Variable to store the column name that will feed to predicate function. + std::vector in_columns_; + + // Internal queue for filter. + QueueList, filterCtrl>> filter_queues_; + + // Private function for worker/thread to loop continuously. It comprises the main + // logic of FilterOp, getting the data from previous Op, validating user specified column names, + // applying predicate to each of the data, filter the data when predicate result is false. + // @param worker_id The id assigned to this thread/worker upon creation. + // @return Status The error code return. + Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ + + // Filter the data by predicate function . + // @param in_buffer input data buffer. + // @param to_proess_indices Indices of columns to be processed. + // @param out data buffer that are filtered by predicate. + // @return Status The error code return. + Status WorkerCompute(DataBuffer *in_buffer, const std::vector &to_proess_indices, + std::unique_ptr *out); + + // Collector databuffer. + // @return Status The error code return. + Status Collector(); + + // @param input tensor vector. + // @return Status - The error code return. + Status CheckInput(const TensorRow &input) const; + + // Invoke python func. + // @param input tensor vector. + // @param the result of predicate. + // @return Status - The error code return. + Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); + + // Private function for validating if each of the user specified input column names + // exist in the DataBuffer. + // @param col_name_id_map The column name to index mapping obtained from DataBuffer. + // @param input_columns The vector of input column names used in the current thread. + // @return Status The error code return. + Status ValidateInColumns(const std::unordered_map &col_name_id_map, + std::vector *input_columns); + + // Private function that initialize some internal data structure used by WorkerEntry(). + // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory + // and is not shared with other threads. + // @param[out] to_process_indices Indices of columns that will feed to predicate. + // @param input_columns The vector of input column names used in the current thread. + Status WorkerEntryInit(const DataBuffer *in_buf, std::vector *to_process_indices, + std::vector *input_columns); +}; + +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7c4857a5806..57ce07b927b 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -35,7 +35,7 @@ from mindspore._c_expression import typing from mindspore import log as logger from . 
import samplers from .iterators import DictIterator, TupleIterator -from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ +from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ check_zip_dataset, check_add_column, check_textfiledataset @@ -385,6 +385,32 @@ class Dataset: """ return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers) + @check_filter + def filter(self, predicate, input_columns=None, num_parallel_workers=1): + """ + Filter dataset by predicate. + + Note: + If input_columns not provided or empty, all columns will be used. + + Args: + predicate: python callable which returns a boolean value. + input_columns: (list[str]): List of names of the input columns, when + default=None, the predicate will be applied on all columns in the dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset + in parallel (default=None). + + Returns: + FilterDataset, dataset filter. + + Examples: + >>> import mindspore.dataset as ds + >>> # generator data(0 ~ 63) + >>> # filter the data that greater than or equal to 11 + >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"]) + """ + return FilterDataset(self, predicate, input_columns, num_parallel_workers) + @check_repeat def repeat(self, count=None): """ @@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp): return self.input[0].get_dataset_size() +class FilterDataset(DatasetOp): + """ + The result of applying filter predicate to the input Dataset. + + Args: + input_dataset: Input Dataset to be mapped. + predicate: python callable which returns a boolean value. + input_columns: (list[str]): List of names of the input columns, when + default=None, the predicate will be applied all columns in the dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset + in parallel (default=None). + """ + + def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): + super().__init__(num_parallel_workers) + self.predicate = lambda *args: bool(predicate(*args)) + self.input.append(input_dataset) + input_dataset.output.append(self) + if input_columns is not None and not isinstance(input_columns, list): + input_columns = [input_columns] + self.input_columns = input_columns + + def get_args(self): + args = super().get_args() + args["predicate"] = self.predicate + args["input_columns"] = self.input_columns + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + the size cannot be determined before we run the pipeline + Return: + 0 + """ + return 0 + + class RepeatDataset(DatasetOp): """ The result of applying Repeat operator to the input Dataset. 
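For orientation while reviewing the C++ plumbing, this is the new Python surface end to end; it mirrors the docstring example above and the tests added in test_filterop.py below. Note that get_dataset_size() deliberately returns 0, since the filtered size is only known after the pipeline runs.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(64):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, ["data"])
# Keep only the rows whose value is below 11.
data = data.filter(predicate=lambda data: data < 11, input_columns=["data"],
                   num_parallel_workers=4)
assert data.get_dataset_size() == 0   # size cannot be determined before the run
assert [int(item["data"]) for item in data.create_dict_iterator()] == list(range(11))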
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index a74d69b9c7b..6af6c7dba8e 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -129,6 +129,8 @@ class Iterator: op_type = OpName.ZIP elif isinstance(dataset, de.MapDataset): op_type = OpName.MAP + elif isinstance(dataset, de.FilterDataset): + op_type = OpName.FILTER elif isinstance(dataset, de.RepeatDataset): op_type = OpName.REPEAT elif isinstance(dataset, de.SkipDataset): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 5de113fd727..a68d723f1d3 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -693,6 +693,26 @@ def check_map(method): return new_method +def check_filter(method): + """"check the input arguments of filter.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + predicate = param_dict.get("predicate") + if not callable(predicate): + raise ValueError("Predicate should be a python function or a callable python object.") + + nreq_param_int = ['num_parallel_workers'] + check_param_type(nreq_param_int, param_dict, int) + param_name = "input_columns" + param = param_dict.get(param_name) + if param is not None: + check_columns(param, param_name) + return method(*args, **kwargs) + + return new_method + + def check_repeat(method): """check the input arguments of repeat.""" @wraps(method) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index b05f12eee12..2224565c309 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -66,6 +66,8 @@ SET(DE_UT_SRCS celeba_op_test.cc take_op_test.cc text_file_op_test.cc) + filter_op_test.cc + ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/filter_op_test.cc b/tests/ut/cpp/dataset/filter_op_test.cc new file mode 100644 index 00000000000..45ee714337e --- /dev/null +++ b/tests/ut/cpp/dataset/filter_op_test.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
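The new check_filter validator rejects bad arguments before any C++ op is built: predicate must be callable, num_parallel_workers must be an int, and input_columns goes through check_columns. A quick illustration of the callable check:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(5):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, ["data"])
try:
    data.filter(predicate=10, input_columns=["data"])   # not callable
except ValueError as err:
    print(err)   # Predicate should be a python function or a callable python object.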
+ */ +#include "dataset/util/circular_pool.h" +#include "dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +namespace de = mindspore::dataset; + +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestfilter_op : public UT::DatasetOpTesting { + +}; + + +std::shared_ptr Filter() { + Status rc; + std::shared_ptr op; + rc = de::FilterOp::Builder().Build(&op); + EXPECT_TRUE(rc.IsOk()); + return op; +} + +TEST_F(MindDataTestfilter_op, Testfilter_opFuntions) { + MS_LOG(INFO) << "Doing MindDataTest filter_op."; + auto my_tree = std::make_shared(); + + std::shared_ptr parent_op = Filter(); + + std::shared_ptr leaf_op = Filter(); + my_tree->AssociateNode(parent_op); + my_tree->AssociateNode(leaf_op); + ASSERT_NE(parent_op, nullptr); + ASSERT_NE(leaf_op, nullptr); +} diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index 7437b3d9424..494d4b23290 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { ASSERT_EQ(*t == *t6, true); } +// Test the bug of Tensor::ToString will exec failed for Tensor which store bool values +TEST_F(MindDataTestTensorDE, BoolTensor) { + std::shared_ptr t = std::make_shared(TensorShape({2}), + DataType(DataType::DE_BOOL)); + t->SetItemAt({0}, true); + t->SetItemAt({1}, true); + std::string out = t->ToString(); + ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos); +} + TEST_F(MindDataTestTensorDE, GetItemAt) { std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); t->Fill(254); diff --git a/tests/ut/data/dataset/declient_filter.cfg b/tests/ut/data/dataset/declient_filter.cfg new file mode 100644 index 00000000000..89e1199f5a2 --- /dev/null +++ b/tests/ut/data/dataset/declient_filter.cfg @@ -0,0 +1,3 @@ +{ + "rowsPerBuffer": 10, +} diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py new file mode 100644 index 00000000000..90f512caa40 --- /dev/null +++ b/tests/ut/python/dataset/test_filterop.py @@ -0,0 +1,504 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
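The kFilterPartial / kFilterEmpty branches in the C++ op only trigger when a DataBuffer carries more than one row; the Python tests below arrange that by loading declient_filter.cfg (rowsPerBuffer = 10) before building the pipeline, along these lines:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(64):
        yield (np.array(i),)

ds.config.load('../data/dataset/declient_filter.cfg')   # 10 rows per buffer
data = ds.GeneratorDataset(gen, ["data"])
data = data.filter(predicate=lambda d: d % 3 == 0, num_parallel_workers=2)
for item in data.create_dict_iterator():
    assert item["data"] % 3 == 0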
+# ============================================================================== + +import numpy as np +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as cde +import mindspore.dataset.transforms.c_transforms as C +import mindspore.common.dtype as mstype +from mindspore import log as logger + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +# test for predicate +def test_diff_predicate_func(): + def test_filter(predicate_func): + transforms = [ + cde.Decode(), + cde.Resize([64, 64]) + ] + type_cast_op = C.TypeCast(mstype.int32) + dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False) + dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1) + dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4) + + num_iter = 0 + label_list = [] + for data in dataset.create_dict_iterator(): + num_iter += 1 + ori_img = data["image"] + label = data["label"] + label_list.append(label) + assert num_iter == 1 + assert label_list[0] == 3 + + test_filter(lambda image, label: label == 3) + test_filter(lambda image, label: label[0] == 3) + test_filter(lambda image, label: label == [3]) + test_filter(lambda image, label: label == np.array([3])) + test_filter(lambda image, label: label == np.array(3)) + +def filter_func_ge(data): + if data > 10: + return False + return True + + +def generator_1d(): + for i in range(64): + yield (np.array(i),) + +# test with GeneratorDataset +def test_filter_by_generator_with_no(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + num_iter = 0 + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_f.create_dict_iterator(): + assert item["data"] == expected_rs[num_iter] + num_iter += 1 + +# test with repeatOp before +def test_filter_by_generator_with_repeat(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_r = dataset.repeat(4) + dataset_f = dataset_r.filter(predicate=filter_func_ge, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 44 + for i in range(4): + for ii in range(len(expected_rs)): + index = i * len(expected_rs) + ii + assert ret_data[index] == expected_rs[ii] + +# test with repeatOp after +def test_filter_by_generator_with_repeat_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4) + dataset_r = dataset_f.repeat(4) + num_iter = 0 + ret_data = [] + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_r.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 44 + for i in range(4): + for ii in range(len(expected_rs)): + index = i * len(expected_rs) + ii + assert ret_data[index] == expected_rs[ii] + +def filter_func_batch(data): + if data[0] > 8: + return False + return True + +def filter_func_batch_after(data): + if data > 20: + return False + return True + +# test with batchOp before +def test_filter_by_generator_with_batch(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_b = dataset.batch(4) + dataset_f = 
dataset_b.filter(predicate=filter_func_batch, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 3 + assert ret_data[0][0] == 0 + assert ret_data[1][0] == 4 + assert ret_data[2][0] == 8 + +# test with batchOp after +def test_filter_by_generator_with_batch_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_batch_after, num_parallel_workers=4) + dataset_b = dataset_f.batch(4) + num_iter = 0 + ret_data = [] + for item in dataset_b.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 6 + assert ret_data[0][0] == 0 + assert ret_data[1][0] == 4 + assert ret_data[5][0] == 20 + + +def filter_func_shuffle(data): + if data > 20: + return False + return True + +# test with batchOp before +def test_filter_by_generator_with_shuffle(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_s = dataset.shuffle(4) + dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4) + num_iter = 0 + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + assert num_iter == 21 + + +def filter_func_shuffle_after(data): + if data > 20: + return False + return True + +# test with batchOp after +def test_filter_by_generator_with_shuffle_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4) + dataset_s = dataset_f.shuffle(4) + num_iter = 0 + for item in dataset_s.create_dict_iterator(): + num_iter += 1 + assert num_iter == 21 + + +def generator_1d_zip1(): + for i in range(64): + yield (np.array(i),) + + +def generator_1d_zip2(): + for i in range(64): + yield (np.array(i+100),) + + +def filter_func_zip(data1, data2): + if data1 > 20: + return False + return True + +def filter_func_zip_after(data1): + if data1 > 20: + return False + return True + +# test with zipOp before +def test_filter_by_generator_with_zip(): + dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"]) + dataset2 = ds.GeneratorDataset(generator_1d_zip2, ["data2"]) + dataz = ds.zip((dataset1, dataset2)) + dataset_f = dataz.filter(predicate=filter_func_zip, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append({"data1": item["data1"], "data2":item["data2"]}) + assert num_iter == 21 + assert ret_data[0]["data1"] == 0 + assert ret_data[0]["data2"] == 100 + assert ret_data[5]["data1"] == 5 + assert ret_data[5]["data2"] == 105 + + +# test with zipOp after +def test_filter_by_generator_with_zip_after(): + dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"]) + dataset2 = ds.GeneratorDataset(generator_1d_zip1, ["data2"]) + dt1 = dataset1.filter(predicate=filter_func_zip_after, num_parallel_workers=4) + dt2 = dataset2.filter(predicate=filter_func_zip_after, num_parallel_workers=4) + dataz = ds.zip((dt1, dt2)) + num_iter = 0 + ret_data = [] + for item in dataz.create_dict_iterator(): + num_iter += 1 + ret_data.append({"data1": item["data1"], "data2":item["data2"]}) + assert num_iter == 21 + assert ret_data[0]["data1"] == 0 + assert ret_data[0]["data2"] == 0 + assert ret_data[5]["data1"] == 5 + assert ret_data[5]["data2"] == 5 + + +def filter_func_map(col1, col2): + if col1[0] > 8: + return True + return False + + +def filter_func_map_part(col1): + if col1 < 3: + return True + else: + return False + + +def 
filter_func_map_all(col1, col2): + return True + +def generator_mc(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + + +def func_map(data_col1, data_col2): + return (data_col1, data_col2) + + +def func_map_part(data_col1): + return (data_col1) + +# test with map +def test_filter_by_generator_with_map_all_col(): + dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["col1"] , operations=func_map_part) + # dataset_map = dataset.map( operations=func_map_part) + dataset_f = dataset_map.filter(input_columns=["col1"], predicate=filter_func_map_part, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["col1"]) + assert num_iter == 3 + assert ret_data[0] == 0 + assert ret_data[1] == 1 + +# test with map +def test_filter_by_generator_with_map_part_col(): + dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part) + + dataset_f = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_map, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + print(item) + ret_data.append(item["out1"]) + assert num_iter == 3 + assert ret_data[0] == 9 + assert ret_data[2] == 11 + + +def filter_func_rename(data): + if data> 8: + return True + return False + +# test with rename before +def test_filter_by_generator_with_rename(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_b = dataset.rename(input_columns=["data"], output_columns=["col1"]) + dataset_f = dataset_b.filter(predicate=filter_func_rename, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["col1"]) + assert num_iter == 55 + assert ret_data[0] == 9 + assert ret_data[54] == 63 + + +#test input_column +def filter_func_input_column1(col1, col2): + if col1[0] < 8: + return True + return False + +def filter_func_input_column2(col1): + if col1[0] < 8: + return True + return False + +def filter_func_input_column3(col1): + return True + +# test with input_columns +def test_filter_by_generator_with_input_column(): + dataset = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part) + dataset_f1 = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_input_column1, num_parallel_workers=4) + dataset_f2 = dataset_f1.filter(input_columns=["out1"], predicate=filter_func_input_column2, num_parallel_workers=4) + dataset_f3 = dataset_f2.filter(input_columns=["col2"], predicate=filter_func_input_column3, num_parallel_workers=4) + dataset_f4 = dataset_f3.filter(predicate=filter_func_input_column1, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f4.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["out1"]) + assert num_iter == 8 + assert ret_data[0] == 0 + assert ret_data[7] == 7 + + +#test kFilterPartial +def generator_mc_p0(maxid=20): + for i in range(maxid): + yield (np.array([i ]), np.array([i + 100])) + +def generator_mc_p1(maxid=20): + for i in range(maxid): + yield (np.array([i + 200 ]), np.array([i + 300])) + + +def filter_func_Partial_0(col1, col2, col3, col4): + filter_data = 
[0,1,2,3,4, 11] + if col1[0] in filter_data: + return False + return True + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial0(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + dataset_zip = ds.zip((dataset1, dataset2)) + dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2) + ret = [] + for item in dataset_f1.create_dict_iterator(): + ret.append(item["col1"]) + assert ret[0] == 5 + assert ret[6] == 12 + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial1(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + dataset_zip = ds.zip((dataset1, dataset2)) + dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2) + dataset_map = dataset_f1.map( input_columns=["col1"], output_columns=["out1"] , operations=lambda x1: x1 + 400) + ret = [] + for item in dataset_map.create_dict_iterator(): + ret.append(item["out1"]) + assert ret[0] == 405 + assert ret[6] == 412 + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial2(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + + dataset1f = dataset1.filter( input_columns= ["col1"], predicate=lambda x: x not in [3,7,9], num_parallel_workers=2) + dataset2f = dataset2.filter( input_columns= ["col3"], predicate=lambda x: x not in [203,207,209], num_parallel_workers=2) + dataset_zip = ds.zip((dataset1f, dataset2f)) + dataset_map = dataset_zip.map( input_columns=["col1", "col3"], output_columns=["out1", "out3"] , operations=lambda x1,x3: (x1 + 400, x3+500)) + ret1 = [] + ret3 = [] + for item in dataset_map.create_dict_iterator(): + ret1.append(item["out1"]) + ret3.append(item["out3"]) + assert ret1[0] == 400 + assert ret1[6] == 408 + assert ret3[0] == 700 + assert ret3[6] == 708 + + +def filter_func_Partial(col1, col2): + if col1[0] % 3 == 0: + return True + return False + +def generator_big(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset = ds.GeneratorDataset(source = generator_mc(99), column_names = ["col1", "col2"]) + dataset_s = dataset.shuffle(4) + dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) + + for item in dataset_f1.create_dict_iterator(): + assert item["col1"] % 3 == 0 + + +def filter_func_cifar(col1, col2): + if col2 % 3 == 0: + return True + return False + +# test with cifar10 +def test_filte_case_dataset_cifar10(): + DATA_DIR_10 = "../data/dataset/testCifar10Data" + ds.config.load('../data/dataset/declient_filter.cfg') + dataset_c = ds.Cifar10Dataset(dataset_dir = DATA_DIR_10, num_samples = 100000, shuffle=False) + dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1) + num_iter = 0 + for item in dataset_f1.create_dict_iterator(): 
+ # in this example, each dictionary has keys "image" and "label" + assert item["label"] % 3 == 0 + +# column id sort + +def generator_sort1(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([i + 100]), np.array([i + 200])) + +def generator_sort2(maxid=20): + for i in range(maxid): + yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500])) + + +def filter_func_part_sort(col1, col2, col3, col4, col5, col6): + return True + +def filter_func_map_sort(col1, col2, col3): + return (col1, col2, col3) + +def test_filter_by_generator_with_map_all_sort(): + dataset1 = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"]) + dataset2 = ds.GeneratorDataset(generator_sort2(10), ["col4 ", "col5", "col6"]) + + dataz = ds.zip((dataset1, dataset2)) + dataset_f = dataz.filter(predicate=filter_func_part_sort, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item) + + assert num_iter == 10 + assert ret_data[0]["col1"] == 0 + assert ret_data[9]["col6"] == 509 + + + +if __name__ == '__main__': + test_diff_predicate_func() + test_filte_case_dataset_cifar10() + test_filter_by_generator_Partial0() + test_filter_by_generator_Partial1() + test_filter_by_generator_Partial2() + test_filter_by_generator_with_batch() + test_filter_by_generator_with_batch_after() + test_filter_by_generator_with_input_column() + test_filter_by_generator_with_map_all_col() + test_filter_by_generator_with_map_all_sort() + test_filter_by_generator_with_map_part_col() + test_filter_by_generator_with_no() + test_filter_by_generator_with_rename() + test_filter_by_generator_with_repeat() + test_filter_by_generator_with_repeat_after() + test_filter_by_generator_with_shuffle() + test_filter_by_generator_with_shuffle_after() + test_filter_by_generator_with_zip() + test_filter_by_generator_with_zip_after() + test_filter_by_generator_Partial() diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index 102fd0eea1c..7c69adf5616 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", def check(project_columns): - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS) - data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns) + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False) + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False) for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()): assert len(data_actual) == len(data_expected)
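When input_columns is omitted, WorkerEntryInit() sorts the column-name map by column id and hands every column to the predicate in that order, which is exactly what the zip/sort test above relies on. A compact version of that scenario:

import numpy as np
import mindspore.dataset as ds

def gen1(maxid=10):
    for i in range(maxid):
        yield (np.array([i]), np.array([i + 100]), np.array([i + 200]))

def gen2(maxid=10):
    for i in range(maxid):
        yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500]))

data = ds.zip((ds.GeneratorDataset(gen1, ["col1", "col2", "col3"]),
               ds.GeneratorDataset(gen2, ["col4", "col5", "col6"])))

# No input_columns: the predicate receives all six columns in column-id order.
def keep_small(col1, col2, col3, col4, col5, col6):
    return bool(col1[0] < 5)

data = data.filter(predicate=keep_small, num_parallel_workers=1)
assert sum(1 for _ in data.create_dict_iterator()) == 5

Separately, the shuffle=False added to both TFRecordDataset calls in test_iterator.py keeps the two pipelines in the same row order, so the tuple/dict comparison there stays deterministic.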