add filterOp code
Commit c705ea5e5b (parent d8176a77f4), forked from mindspore-Ecosystem/mindspore
@ -29,6 +29,7 @@
|
|||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include "mindrecord/include/shard_category.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_shuffle.h"
|
||||
|
@ -45,6 +46,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
|
|||
{kShuffle, &DEPipeline::ParseShuffleOp},
|
||||
{kMindrecord, &DEPipeline::ParseMindRecordOp},
|
||||
{kMap, &DEPipeline::ParseMapOp},
|
||||
{kFilter, &DEPipeline::ParseFilterOp},
|
||||
{kBatch, &DEPipeline::ParseBatchOp},
|
||||
{kRepeat, &DEPipeline::ParseRepeatOp},
|
||||
{kSkip, &DEPipeline::ParseSkipOp},
|
||||
|
@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<FilterOp::Builder> builder = std::make_shared<FilterOp::Builder>();
|
||||
|
||||
if (args["predicate"].is_none()) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n");
|
||||
}
|
||||
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "predicate") {
|
||||
py::handle op = args["predicate"];
|
||||
if (!py::isinstance<py::function>(op)) {
|
||||
RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc).");
|
||||
}
|
||||
py::function predicate_func = op.cast<py::function>();
|
||||
(void)builder->SetPredicateFunc(std::move(predicate_func));
|
||||
} else if (key == "input_columns") {
|
||||
std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]);
|
||||
(void)builder->SetInColNames(in_col_names);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<FilterOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
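
For orientation: the args dict consumed above is assembled on the Python side by FilterDataset.get_args() later in this commit. A hedged sketch of its expected shape (whether additional keys are contributed by the DatasetOp base class is an assumption here):

# Illustrative only -- approximate contents of the dict handed to ParseFilterOp.
args = {
    "predicate": lambda data: bool(data < 11),  # python callable returning a bool
    "input_columns": ["data"],                  # or None to apply to all columns
    "num_parallel_workers": 4,                  # assumed to come from the DatasetOp base class
}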
|
||||
|
||||
Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
if (args["count"].is_none()) {
|
||||
std::string err_msg = "Error: count is invalid or not set.";
|
||||
|
@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); }
|
||||
|
||||
Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
// Required arguments
|
||||
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
|
||||
|
|
|
@ -107,6 +107,8 @@ class DEPipeline {
|
|||
|
||||
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
@ -121,8 +123,6 @@ class DEPipeline {
|
|||
|
||||
Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
DsOpPtr ParseFilterOp(const py::dict &args) const;
|
||||
|
||||
Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "dataset/engine/datasetops/map_op.h"
|
||||
#include "dataset/engine/datasetops/project_op.h"
|
||||
#include "dataset/engine/datasetops/rename_op.h"
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include "dataset/engine/datasetops/repeat_op.h"
|
||||
#include "dataset/engine/datasetops/skip_op.h"
|
||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
|
|
|
@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c
|
|||
DS_ASSERT(data_);
|
||||
|
||||
switch (type_.value()) {
|
||||
CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t);
|
||||
CASE_PRINT_HEX(DataType::DE_BOOL, bool);
|
||||
|
||||
CASE_PRINT_HEX(DataType::DE_INT8, int8_t);
|
||||
|
||||
|
|
|
@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT
|
|||
take_op.cc
|
||||
shuffle_op.cc
|
||||
zip_op.cc
|
||||
filter_op.cc
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,273 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status FilterOp::Builder::SanityCheck() {
|
||||
std::string err;
|
||||
err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : "";
|
||||
err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : "";
|
||||
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
|
||||
}
|
||||
|
||||
FilterOp::Builder::Builder() {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<FilterOp>(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_,
|
||||
builder_predicate_func_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
|
||||
py::function predicate_func)
|
||||
: ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}
|
||||
|
||||
Status FilterOp::operator()() {
|
||||
// The operator class just starts off worker threads through the tree_ (ExecutionTree).
|
||||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
// Synchronize with TaskManager.
|
||||
TaskManager::FindMe()->Post();
|
||||
filter_queues_.Init(num_workers_, oc_queue_size_);
|
||||
RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)));
|
||||
RETURN_IF_NOT_OK(Collector());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FilterOp::EofReceived(int32_t) { return Status::OK(); }
|
||||
|
||||
Status FilterOp::EoeReceived(int32_t) { return Status::OK(); }
|
||||
|
||||
// Validating if each of the input_columns exists in the DataBuffer.
|
||||
Status FilterOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
|
||||
std::vector<std::string> *input_columns) {
|
||||
for (const auto &inCol : *input_columns) {
|
||||
bool found = col_name_id_map.find(inCol) != col_name_id_map.end();
|
||||
if (!found) {
|
||||
std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A print method typically used for debugging.
|
||||
void FilterOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Call base class printer first.
|
||||
ParallelOp::Print(out, show_all);
|
||||
|
||||
// Then display our own stuff.
|
||||
out << "\nFilterOp:";
|
||||
out << "\n Input column names:";
|
||||
for (size_t i = 0; i < in_columns_.size(); i++) {
|
||||
out << " " << in_columns_[i];
|
||||
}
|
||||
}
|
||||
|
||||
Status FilterOp::WorkerEntry(int32_t worker_id) {
|
||||
// Handshake with TaskManager that thread creation is successful.
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<DataBuffer> in_buffer;
|
||||
bool worker_stop = false;
|
||||
while (worker_stop == false) {
|
||||
// Getting a databuffer to work on.
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
|
||||
if (in_buffer->eoe()) {
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
|
||||
continue;
|
||||
} else if (in_buffer->eof()) {
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
|
||||
worker_stop = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Thread-local copy to avoid locking. When in_columns_ is empty, workers write
|
||||
// the resolved column names into input_columns (thread local) instead of into the shared in_columns_.
|
||||
std::vector<std::string> input_columns = in_columns_;
|
||||
// Indices of the columns to process.
|
||||
std::vector<size_t> to_process_indices;
|
||||
|
||||
RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns));
|
||||
|
||||
// if the databuffer was all filtered, it is marked as kFilterEmpty.
|
||||
// if the databuffer was partially filtered, it is marked as kFilterPartial.
|
||||
// if the databuffer was not filtered, it is marked as kFilterFull.
|
||||
int32_t num_rows = in_buffer->NumRows();
|
||||
std::unique_ptr<TensorQTable> new_tensor_table;
|
||||
RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table));
|
||||
|
||||
if (new_tensor_table->empty()) {
|
||||
RETURN_IF_NOT_OK(
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty)));
|
||||
} else if (new_tensor_table->size() == num_rows) {
|
||||
in_buffer->set_tensor_table(std::move(new_tensor_table));
|
||||
RETURN_IF_NOT_OK(
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull)));
|
||||
} else { // kFilterPartial
|
||||
in_buffer->set_tensor_table(std::move(new_tensor_table));
|
||||
RETURN_IF_NOT_OK(
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial)));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
|
||||
std::unique_ptr<TensorQTable> *out) {
|
||||
*out = std::make_unique<TensorQTable>();
|
||||
int32_t num_rows = in_buffer->NumRows();
|
||||
for (int32_t i = 0; i < num_rows; i++) {
|
||||
TensorRow to_process;
|
||||
TensorRow cur_row;
|
||||
RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row));
|
||||
|
||||
(void)std::transform(to_proess_indices.begin(), to_proess_indices.end(), std::back_inserter(to_process),
|
||||
[&cur_row](const size_t &it) -> std::shared_ptr<Tensor> { return cur_row[it]; });
|
||||
bool predicate = true;
|
||||
RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate));
|
||||
if (predicate) {
|
||||
(*out)->push_back(std::move(cur_row));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
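
In Python terms, the per-row logic of WorkerCompute above is roughly the following (an illustration only, not the implementation):

# Select the to-process columns from each row, call the predicate on them, and
# keep the full row only when the predicate returns True.
def worker_compute(rows, to_process_indices, predicate):
    kept = []
    for row in rows:                      # each row is a list of per-column arrays
        to_process = [row[i] for i in to_process_indices]
        if bool(predicate(*to_process)):
            kept.append(row)
    return kept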
|
||||
|
||||
// if the filtered DataBuffer is written directly to out_connector_,
|
||||
// the thread fetching data will block in a queue.
|
||||
// The Collector function restores the original DataBuffer order.
|
||||
// For example, with two worker queues:
|
||||
// in filter_queues_:
|
||||
// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
|
||||
// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
|
||||
// after reorder in out_connector_:
|
||||
// queue1: DB(data2) DB(data4) DB(eof)
|
||||
// queue2: DB(eoe) DB(eoe)
|
||||
Status FilterOp::Collector() {
|
||||
bool collector_stop = false;
|
||||
uint64_t task_id_cnt = 0;
|
||||
uint64_t out_id_cnt = 0;
|
||||
std::pair<std::unique_ptr<DataBuffer>, filterCtrl> in_pair;
|
||||
while (collector_stop == false) {
|
||||
uint32_t w_id = task_id_cnt % num_workers_;
|
||||
RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
|
||||
if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
|
||||
in_pair.second == filterCtrl::kFilterEoe) {
|
||||
uint32_t out_task_id = out_id_cnt % num_workers_;
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
|
||||
out_id_cnt++;
|
||||
task_id_cnt++;
|
||||
} else if (in_pair.second == filterCtrl::kFilterEof) {
|
||||
uint32_t out_task_id = out_id_cnt % num_workers_;
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
|
||||
collector_stop = true;
|
||||
} else { // kFilterEmpty
|
||||
task_id_cnt++;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
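
The reordering described in the comment above is easiest to follow with a toy walk-through; here is a small Python sketch (an illustration of the scheme, not the C++ implementation):

# Buffers are popped round-robin from the per-worker queues; kFilterEmpty buffers
# only advance the cursor, everything else is forwarded downstream in task order.
from collections import deque

EMPTY, FULL, EOE, EOF = "empty", "full", "eoe", "eof"

def collect(worker_queues):
    out = []
    task_id = 0
    while True:
        buf, ctrl = worker_queues[task_id % len(worker_queues)].popleft()
        if ctrl == EOF:      # final buffer: forward it and stop
            out.append(buf)
            return out
        if ctrl != EMPTY:    # non-empty buffers and EOE markers are forwarded
            out.append(buf)
        task_id += 1

queues = [deque([("data1", EMPTY), ("eoe", EOE), ("data4", FULL), ("eof", EOF)]),
          deque([("data2", FULL), ("data3", EMPTY), ("eoe", EOE)])]
print(collect(queues))  # ['data2', 'eoe', 'data4', 'eoe', 'eof']
# Distributed round-robin over the two output queues, this matches the example
# above: queue1 gets data2, data4, eof and queue2 gets eoe, eoe.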
|
||||
|
||||
// initialize some internal data structure used by WorkerEntry().
|
||||
Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
|
||||
std::vector<std::string> *input_columns) {
|
||||
int32_t num_rows = in_buf->NumRows();
|
||||
int32_t num_cols = in_buf->NumCols();
|
||||
if (num_rows == 0 || num_cols == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer.");
|
||||
}
|
||||
std::unordered_map<std::string, int32_t> col_name_id_map = in_buf->column_name_map();
|
||||
// Check if there is invalid column name in the inColumns.
|
||||
RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns));
|
||||
|
||||
if (input_columns->empty()) {
|
||||
MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table.";
|
||||
// Sort the input columns by column index.
|
||||
std::vector<std::pair<std::string, int32_t>> sort_vec(col_name_id_map.begin(), col_name_id_map.end());
|
||||
std::sort(sort_vec.begin(), sort_vec.end(),
|
||||
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
|
||||
return a.second < b.second;
|
||||
});
|
||||
|
||||
(void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns),
|
||||
[](const auto &it) -> std::string { return it.first; });
|
||||
}
|
||||
|
||||
// initialize to_process_indices.
|
||||
(void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices),
|
||||
[&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; });
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FilterOp::CheckInput(const TensorRow &input) const {
|
||||
for (auto &item : input) {
|
||||
if (item == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("input is null.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) {
|
||||
RETURN_IF_NOT_OK(CheckInput(input));
|
||||
// Acquire Python GIL.
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
if (Py_IsInitialized() == 0) {
|
||||
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
|
||||
}
|
||||
try {
|
||||
// Transform input tensor vector into numpy array vector.
|
||||
py::tuple input_args(input.size());
|
||||
for (size_t i = 0; i < input.size(); i++) {
|
||||
py::array new_data;
|
||||
RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
|
||||
input_args[i] = new_data;
|
||||
}
|
||||
// Invoke python function.
|
||||
py::object ret_py_obj = predicate_func_(*input_args);
|
||||
*out_predicate = ret_py_obj.cast<py::bool_>();
|
||||
} catch (const py::error_already_set &e) {
|
||||
std::stringstream ss;
|
||||
ss << e.what() << std::endl;
|
||||
ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool.";
|
||||
return Status(StatusCode::kPyFuncException, ss.str());
|
||||
}
|
||||
return Status(StatusCode::kOK, "FilterOp predicate func call succeeded");
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,180 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "dataset/engine/datasetops/parallel_op.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "dataset/util/queue.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class FilterOp : public ParallelOp {
|
||||
public:
|
||||
// The nested builder class inside of the FilterOp is used to help manage all of
|
||||
// the arguments for constructing it. Use the builder by setting each argument
|
||||
// with the provided set methods, and then finally call the build method to execute
|
||||
// the actual construction.
|
||||
class Builder {
|
||||
public:
|
||||
// Builder constructor. Creates the builder object.
|
||||
// @note No default args.
|
||||
// @return This is a constructor.
|
||||
Builder();
|
||||
|
||||
// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
// Setter method.
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetPredicateFunc(py::function func) {
|
||||
builder_predicate_func_ = std::move(func);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
|
||||
build_in_col_names_ = in_col_names;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
builder_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||
builder_op_connector_size_ = connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param ptr The shared_ptr to the new FilterOp object.
|
||||
// @return Status.
|
||||
Status Build(std::shared_ptr<FilterOp> *ptr);
|
||||
|
||||
private:
|
||||
// Sanity check for builder class args.
|
||||
// @return Status - The error code return.
|
||||
Status SanityCheck();
|
||||
std::vector<std::string> build_in_col_names_;
|
||||
py::function builder_predicate_func_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
};
|
||||
|
||||
enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
|
||||
|
||||
// Constructor of FilterOp
|
||||
// @note The builder class should be used to call it.
|
||||
// @param in_col_names A list of input column names; when it is empty the predicate will be
|
||||
// applied to all columns in the dataset.
|
||||
// @param num_workers The number of worker threads.
|
||||
// @param op_connector_size The size of each queue in the connector.
|
||||
// @param predicate_func python callable which returns a boolean value.
|
||||
FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
|
||||
py::function predicate_func);
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work.
|
||||
// @return Status The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// @param int32_t workerId.
|
||||
// @return Status - The error code return.
|
||||
Status EofReceived(int32_t) override;
|
||||
|
||||
// @param int32_t workerId.
|
||||
// @return Status - The error code return.
|
||||
Status EoeReceived(int32_t) override;
|
||||
|
||||
// A print method typically used for debugging.
|
||||
// @param out The output stream to write output to.
|
||||
// @param show_all A bool to control if you want to show all info or just a summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
private:
|
||||
// predicate_func python callable which returns a boolean value.
|
||||
py::function predicate_func_;
|
||||
|
||||
// Variable to store the column names that will be fed to the predicate function.
|
||||
std::vector<std::string> in_columns_;
|
||||
|
||||
// Internal queue for filter.
|
||||
QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;
|
||||
|
||||
// Private function for worker/thread to loop continuously. It comprises the main
|
||||
// logic of FilterOp, getting the data from previous Op, validating user specified column names,
|
||||
// applying the predicate to each row, and dropping rows for which the predicate returns false.
|
||||
// @param worker_id The id assigned to this thread/worker upon creation.
|
||||
// @return Status The error code return.
|
||||
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_
|
||||
|
||||
// Filter the data by the predicate function.
|
||||
// @param in_buffer input data buffer.
|
||||
// @param to_proess_indices Indices of columns to be processed.
|
||||
// @param out data buffer that are filtered by predicate.
|
||||
// @return Status The error code return.
|
||||
Status WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
|
||||
std::unique_ptr<TensorQTable> *out);
|
||||
|
||||
// Collect and reorder the filtered data buffers from the worker queues.
|
||||
// @return Status The error code return.
|
||||
Status Collector();
|
||||
|
||||
// @param input tensor vector.
|
||||
// @return Status - The error code return.
|
||||
Status CheckInput(const TensorRow &input) const;
|
||||
|
||||
// Invoke python func.
|
||||
// @param input tensor vector.
|
||||
// @param out_predicate The result of the predicate.
|
||||
// @return Status - The error code return.
|
||||
Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);
|
||||
|
||||
// Private function for validating if each of the user specified input column names
|
||||
// exist in the DataBuffer.
|
||||
// @param col_name_id_map The column name to index mapping obtained from DataBuffer.
|
||||
// @param input_columns The vector of input column names used in the current thread.
|
||||
// @return Status The error code return.
|
||||
Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
|
||||
std::vector<std::string> *input_columns);
|
||||
|
||||
// Private function that initialize some internal data structure used by WorkerEntry().
|
||||
// @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
|
||||
// and is not shared with other threads.
|
||||
// @param[out] to_process_indices Indices of columns that will feed to predicate.
|
||||
// @param input_columns The vector of input column names used in the current thread.
|
||||
Status WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
|
||||
std::vector<std::string> *input_columns);
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif  // DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
|
|
@ -35,7 +35,7 @@ from mindspore._c_expression import typing
|
|||
from mindspore import log as logger
|
||||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator
|
||||
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
|
||||
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
|
||||
check_zip_dataset, check_add_column, check_textfiledataset
|
||||
|
@ -385,6 +385,32 @@ class Dataset:
|
|||
"""
|
||||
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
|
||||
|
||||
@check_filter
|
||||
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
|
||||
"""
|
||||
Filter dataset by predicate.
|
||||
|
||||
Note:
|
||||
If input_columns is not provided or is empty, all columns will be used.
|
||||
|
||||
Args:
|
||||
predicate: python callable which returns a boolean value.
|
||||
input_columns (list[str], optional): List of names of the input columns; when
|
||||
left as the default (None), the predicate will be applied to all columns in the dataset.
|
||||
num_parallel_workers (int, optional): Number of workers to process the Dataset
|
||||
in parallel (default=1).
|
||||
|
||||
Returns:
|
||||
FilterDataset, dataset filtered by the predicate.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> # generator dataset producing data 0 ~ 63
|
||||
>>> # filter out the data that is greater than or equal to 11
|
||||
>>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns=["data"])
|
||||
"""
|
||||
return FilterDataset(self, predicate, input_columns, num_parallel_workers)
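
A slightly larger hedged usage sketch, mirroring the multi-column tests added later in this commit: when input_columns is given, the predicate receives only those columns, while surviving rows keep all of their columns.

import numpy as np
import mindspore.dataset as ds

def gen_mc(maxid=12):
    for i in range(maxid):
        yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))

data = ds.GeneratorDataset(gen_mc(12), ["col1", "col2"])
# The predicate is called with "col1" only; rows where it returns False are dropped,
# but both "col1" and "col2" remain in the surviving rows.
data = data.filter(input_columns=["col1"], predicate=lambda col1: col1[0] < 3,
                   num_parallel_workers=1)
for row in data.create_dict_iterator():
    assert row["col1"][0] < 3
    assert row["col2"].shape == (2, 2)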
|
||||
|
||||
@check_repeat
|
||||
def repeat(self, count=None):
|
||||
"""
|
||||
|
@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp):
|
|||
return self.input[0].get_dataset_size()
|
||||
|
||||
|
||||
class FilterDataset(DatasetOp):
|
||||
"""
|
||||
The result of applying filter predicate to the input Dataset.
|
||||
|
||||
Args:
|
||||
input_dataset: Input Dataset to be filtered.
|
||||
predicate: python callable which returns a boolean value.
|
||||
input_columns (list[str], optional): List of names of the input columns; when
|
||||
left as the default (None), the predicate will be applied to all columns in the dataset.
|
||||
num_parallel_workers (int, optional): Number of workers to process the Dataset
|
||||
in parallel (default=None).
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.predicate = lambda *args: bool(predicate(*args))
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
if input_columns is not None and not isinstance(input_columns, list):
|
||||
input_columns = [input_columns]
|
||||
self.input_columns = input_columns
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["predicate"] = self.predicate
|
||||
args["input_columns"] = self.input_columns
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
Get the number of batches in an epoch.
|
||||
The size cannot be determined before the pipeline is run.
|
||||
Return:
|
||||
0
|
||||
"""
|
||||
return 0
|
||||
|
||||
|
||||
class RepeatDataset(DatasetOp):
|
||||
"""
|
||||
The result of applying Repeat operator to the input Dataset.
|
||||
|
|
|
@ -129,6 +129,8 @@ class Iterator:
|
|||
op_type = OpName.ZIP
|
||||
elif isinstance(dataset, de.MapDataset):
|
||||
op_type = OpName.MAP
|
||||
elif isinstance(dataset, de.FilterDataset):
|
||||
op_type = OpName.FILTER
|
||||
elif isinstance(dataset, de.RepeatDataset):
|
||||
op_type = OpName.REPEAT
|
||||
elif isinstance(dataset, de.SkipDataset):
|
||||
|
|
|
@ -695,6 +695,26 @@ def check_map(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_filter(method):
|
||||
""""check the input arguments of filter."""
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
predicate = param_dict.get("predicate")
|
||||
if not callable(predicate):
|
||||
raise ValueError("Predicate should be a python function or a callable python object.")
|
||||
|
||||
nreq_param_int = ['num_parallel_workers']
|
||||
check_param_type(nreq_param_int, param_dict, int)
|
||||
param_name = "input_columns"
|
||||
param = param_dict.get(param_name)
|
||||
if param is not None:
|
||||
check_columns(param, param_name)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
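
A hedged illustration of the checks wired up above: a non-callable predicate is rejected before any pipeline is built, while callable predicates pass through.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(8):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, ["data"])

# Accepted: predicate is callable, input_columns is a list of str,
# num_parallel_workers is an int.
ok = data.filter(predicate=lambda d: d < 5, input_columns=["data"], num_parallel_workers=2)

# Rejected by check_filter: predicate is not callable.
try:
    data.filter(predicate=123)
except ValueError as err:
    print(err)  # Predicate should be a python function or a callable python object.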
|
||||
|
||||
|
||||
def check_repeat(method):
|
||||
"""check the input arguments of repeat."""
|
||||
@wraps(method)
|
||||
|
|
|
@ -66,6 +66,8 @@ SET(DE_UT_SRCS
|
|||
celeba_op_test.cc
|
||||
take_op_test.cc
|
||||
text_file_op_test.cc
filter_op_test.cc
)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/util/circular_pool.h"
|
||||
#include "dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
namespace de = mindspore::dataset;
|
||||
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestfilter_op : public UT::DatasetOpTesting {
|
||||
|
||||
};
|
||||
|
||||
|
||||
std::shared_ptr<de::FilterOp> Filter() {
|
||||
Status rc;
|
||||
std::shared_ptr<de::FilterOp> op;
|
||||
rc = de::FilterOp::Builder().Build(&op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
return op;
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestfilter_op, Testfilter_opFunctions) {
|
||||
MS_LOG(INFO) << "Doing MindDataTest filter_op.";
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::shared_ptr<DatasetOp> parent_op = Filter();
|
||||
|
||||
std::shared_ptr<DatasetOp> leaf_op = Filter();
|
||||
my_tree->AssociateNode(parent_op);
|
||||
my_tree->AssociateNode(leaf_op);
|
||||
ASSERT_NE(parent_op, nullptr);
|
||||
ASSERT_NE(leaf_op, nullptr);
|
||||
}
|
|
@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) {
|
|||
ASSERT_EQ(*t == *t6, true);
|
||||
}
|
||||
|
||||
// Test the bug where Tensor::ToString failed for a Tensor that stores bool values
|
||||
TEST_F(MindDataTestTensorDE, BoolTensor) {
|
||||
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}),
|
||||
DataType(DataType::DE_BOOL));
|
||||
t->SetItemAt<bool>({0}, true);
|
||||
t->SetItemAt<bool>({1}, true);
|
||||
std::string out = t->ToString();
|
||||
ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTensorDE, GetItemAt) {
|
||||
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 2}), DataType(DataType::DE_UINT8));
|
||||
t->Fill<uint8_t>(254);
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"rowsPerBuffer": 10,
|
||||
}
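
This small config sets rowsPerBuffer to 10 so that a single DataBuffer carries several rows; a predicate that drops only some of them then exercises the kFilterPartial path. The new tests below load it like this:

import mindspore.dataset as ds

# With 10 rows per buffer, a filter that drops part of a buffer's rows produces
# partially filtered buffers (kFilterPartial) rather than only empty or full ones.
ds.config.load('../data/dataset/declient_filter.cfg')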
|
|
@ -0,0 +1,504 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as cde
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
# test for predicate
|
||||
def test_diff_predicate_func():
|
||||
def test_filter(predicate_func):
|
||||
transforms = [
|
||||
cde.Decode(),
|
||||
cde.Resize([64, 64])
|
||||
]
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False)
|
||||
dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1)
|
||||
dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4)
|
||||
|
||||
num_iter = 0
|
||||
label_list = []
|
||||
for data in dataset.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ori_img = data["image"]
|
||||
label = data["label"]
|
||||
label_list.append(label)
|
||||
assert num_iter == 1
|
||||
assert label_list[0] == 3
|
||||
|
||||
test_filter(lambda image, label: label == 3)
|
||||
test_filter(lambda image, label: label[0] == 3)
|
||||
test_filter(lambda image, label: label == [3])
|
||||
test_filter(lambda image, label: label == np.array([3]))
|
||||
test_filter(lambda image, label: label == np.array(3))
|
||||
|
||||
def filter_func_ge(data):
|
||||
if data > 10:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def generator_1d():
|
||||
for i in range(64):
|
||||
yield (np.array(i),)
|
||||
|
||||
# test with GeneratorDataset
|
||||
def test_filter_by_generator_with_no():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
assert item["data"] == expected_rs[num_iter]
|
||||
num_iter += 1
|
||||
|
||||
# test with repeatOp before
|
||||
def test_filter_by_generator_with_repeat():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_r = dataset.repeat(4)
|
||||
dataset_f = dataset_r.filter(predicate=filter_func_ge, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["data"])
|
||||
assert num_iter == 44
|
||||
for i in range(4):
|
||||
for ii in range(len(expected_rs)):
|
||||
index = i * len(expected_rs) + ii
|
||||
assert ret_data[index] == expected_rs[ii]
|
||||
|
||||
# test with repeatOp after
|
||||
def test_filter_by_generator_with_repeat_after():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_f = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4)
|
||||
dataset_r = dataset_f.repeat(4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
for item in dataset_r.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["data"])
|
||||
assert num_iter == 44
|
||||
for i in range(4):
|
||||
for ii in range(len(expected_rs)):
|
||||
index = i * len(expected_rs) + ii
|
||||
assert ret_data[index] == expected_rs[ii]
|
||||
|
||||
def filter_func_batch(data):
|
||||
if data[0] > 8:
|
||||
return False
|
||||
return True
|
||||
|
||||
def filter_func_batch_after(data):
|
||||
if data > 20:
|
||||
return False
|
||||
return True
|
||||
|
||||
# test with batchOp before
|
||||
def test_filter_by_generator_with_batch():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_b = dataset.batch(4)
|
||||
dataset_f = dataset_b.filter(predicate=filter_func_batch, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["data"])
|
||||
assert num_iter == 3
|
||||
assert ret_data[0][0] == 0
|
||||
assert ret_data[1][0] == 4
|
||||
assert ret_data[2][0] == 8
|
||||
|
||||
# test with batchOp after
|
||||
def test_filter_by_generator_with_batch_after():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_f = dataset.filter(predicate=filter_func_batch_after, num_parallel_workers=4)
|
||||
dataset_b = dataset_f.batch(4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_b.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["data"])
|
||||
assert num_iter == 6
|
||||
assert ret_data[0][0] == 0
|
||||
assert ret_data[1][0] == 4
|
||||
assert ret_data[5][0] == 20
|
||||
|
||||
|
||||
def filter_func_shuffle(data):
|
||||
if data > 20:
|
||||
return False
|
||||
return True
|
||||
|
||||
# test with shuffleOp before
|
||||
def test_filter_by_generator_with_shuffle():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_s = dataset.shuffle(4)
|
||||
dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 21
|
||||
|
||||
|
||||
def filter_func_shuffle_after(data):
|
||||
if data > 20:
|
||||
return False
|
||||
return True
|
||||
|
||||
# test with shuffleOp after
|
||||
def test_filter_by_generator_with_shuffle_after():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
|
||||
dataset_s = dataset_f.shuffle(4)
|
||||
num_iter = 0
|
||||
for item in dataset_s.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert num_iter == 21
|
||||
|
||||
|
||||
def generator_1d_zip1():
|
||||
for i in range(64):
|
||||
yield (np.array(i),)
|
||||
|
||||
|
||||
def generator_1d_zip2():
|
||||
for i in range(64):
|
||||
yield (np.array(i+100),)
|
||||
|
||||
|
||||
def filter_func_zip(data1, data2):
|
||||
if data1 > 20:
|
||||
return False
|
||||
return True
|
||||
|
||||
def filter_func_zip_after(data1):
|
||||
if data1 > 20:
|
||||
return False
|
||||
return True
|
||||
|
||||
# test with zipOp before
|
||||
def test_filter_by_generator_with_zip():
|
||||
dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
|
||||
dataset2 = ds.GeneratorDataset(generator_1d_zip2, ["data2"])
|
||||
dataz = ds.zip((dataset1, dataset2))
|
||||
dataset_f = dataz.filter(predicate=filter_func_zip, num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append({"data1": item["data1"], "data2":item["data2"]})
|
||||
assert num_iter == 21
|
||||
assert ret_data[0]["data1"] == 0
|
||||
assert ret_data[0]["data2"] == 100
|
||||
assert ret_data[5]["data1"] == 5
|
||||
assert ret_data[5]["data2"] == 105
|
||||
|
||||
|
||||
# test with zipOp after
|
||||
def test_filter_by_generator_with_zip_after():
|
||||
dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
|
||||
dataset2 = ds.GeneratorDataset(generator_1d_zip1, ["data2"])
|
||||
dt1 = dataset1.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
|
||||
dt2 = dataset2.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
|
||||
dataz = ds.zip((dt1, dt2))
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataz.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append({"data1": item["data1"], "data2":item["data2"]})
|
||||
assert num_iter == 21
|
||||
assert ret_data[0]["data1"] == 0
|
||||
assert ret_data[0]["data2"] == 0
|
||||
assert ret_data[5]["data1"] == 5
|
||||
assert ret_data[5]["data2"] == 5
|
||||
|
||||
|
||||
def filter_func_map(col1, col2):
|
||||
if col1[0] > 8:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def filter_func_map_part(col1):
|
||||
if col1 < 3:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def filter_func_map_all(col1, col2):
|
||||
return True
|
||||
|
||||
def generator_mc(maxid=20):
|
||||
for i in range(maxid):
|
||||
yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
|
||||
|
||||
|
||||
def func_map(data_col1, data_col2):
|
||||
return (data_col1, data_col2)
|
||||
|
||||
|
||||
def func_map_part(data_col1):
|
||||
return (data_col1)
|
||||
|
||||
# test with map
|
||||
def test_filter_by_generator_with_map_all_col():
|
||||
dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
|
||||
dataset_map = dataset.map( input_columns=["col1"], output_columns=["col1"] , operations=func_map_part)
|
||||
# dataset_map = dataset.map( operations=func_map_part)
|
||||
dataset_f = dataset_map.filter(input_columns=["col1"], predicate=filter_func_map_part, num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["col1"])
|
||||
assert num_iter == 3
|
||||
assert ret_data[0] == 0
|
||||
assert ret_data[1] == 1
|
||||
|
||||
# test with map
|
||||
def test_filter_by_generator_with_map_part_col():
|
||||
dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
|
||||
dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part)
|
||||
|
||||
dataset_f = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_map, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
print(item)
|
||||
ret_data.append(item["out1"])
|
||||
assert num_iter == 3
|
||||
assert ret_data[0] == 9
|
||||
assert ret_data[2] == 11
|
||||
|
||||
|
||||
def filter_func_rename(data):
|
||||
if data > 8:
|
||||
return True
|
||||
return False
|
||||
|
||||
# test with rename before
|
||||
def test_filter_by_generator_with_rename():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
dataset_b = dataset.rename(input_columns=["data"], output_columns=["col1"])
|
||||
dataset_f = dataset_b.filter(predicate=filter_func_rename, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["col1"])
|
||||
assert num_iter == 55
|
||||
assert ret_data[0] == 9
|
||||
assert ret_data[54] == 63
|
||||
|
||||
|
||||
# test input_columns
|
||||
def filter_func_input_column1(col1, col2):
|
||||
if col1[0] < 8:
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter_func_input_column2(col1):
|
||||
if col1[0] < 8:
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter_func_input_column3(col1):
|
||||
return True
|
||||
|
||||
# test with input_columns
|
||||
def test_filter_by_generator_with_input_column():
|
||||
dataset = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"])
|
||||
dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part)
|
||||
dataset_f1 = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_input_column1, num_parallel_workers=4)
|
||||
dataset_f2 = dataset_f1.filter(input_columns=["out1"], predicate=filter_func_input_column2, num_parallel_workers=4)
|
||||
dataset_f3 = dataset_f2.filter(input_columns=["col2"], predicate=filter_func_input_column3, num_parallel_workers=4)
|
||||
dataset_f4 = dataset_f3.filter(predicate=filter_func_input_column1, num_parallel_workers=4)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f4.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item["out1"])
|
||||
assert num_iter == 8
|
||||
assert ret_data[0] == 0
|
||||
assert ret_data[7] == 7
|
||||
|
||||
|
||||
# test kFilterPartial
|
||||
def generator_mc_p0(maxid=20):
|
||||
for i in range(maxid):
|
||||
yield (np.array([i ]), np.array([i + 100]))
|
||||
|
||||
def generator_mc_p1(maxid=20):
|
||||
for i in range(maxid):
|
||||
yield (np.array([i + 200 ]), np.array([i + 300]))
|
||||
|
||||
|
||||
def filter_func_Partial_0(col1, col2, col3, col4):
|
||||
filter_data = [0,1,2,3,4, 11]
|
||||
if col1[0] in filter_data:
|
||||
return False
|
||||
return True
|
||||
|
||||
# test with row_data_buffer > 1
|
||||
def test_filter_by_generator_Partial0():
|
||||
ds.config.load('../data/dataset/declient_filter.cfg')
|
||||
dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"])
|
||||
dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"])
|
||||
dataset_zip = ds.zip((dataset1, dataset2))
|
||||
dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
|
||||
ret = []
|
||||
for item in dataset_f1.create_dict_iterator():
|
||||
ret.append(item["col1"])
|
||||
assert ret[0] == 5
|
||||
assert ret[6] == 12
|
||||
|
||||
# test with row_data_buffer > 1
|
||||
def test_filter_by_generator_Partial1():
|
||||
ds.config.load('../data/dataset/declient_filter.cfg')
|
||||
dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"])
|
||||
dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"])
|
||||
dataset_zip = ds.zip((dataset1, dataset2))
|
||||
dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
|
||||
dataset_map = dataset_f1.map( input_columns=["col1"], output_columns=["out1"] , operations=lambda x1: x1 + 400)
|
||||
ret = []
|
||||
for item in dataset_map.create_dict_iterator():
|
||||
ret.append(item["out1"])
|
||||
assert ret[0] == 405
|
||||
assert ret[6] == 412
|
||||
|
||||
# test with row_data_buffer > 1
|
||||
def test_filter_by_generator_Partial2():
|
||||
ds.config.load('../data/dataset/declient_filter.cfg')
|
||||
dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"])
|
||||
dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"])
|
||||
|
||||
dataset1f = dataset1.filter( input_columns= ["col1"], predicate=lambda x: x not in [3,7,9], num_parallel_workers=2)
|
||||
dataset2f = dataset2.filter( input_columns= ["col3"], predicate=lambda x: x not in [203,207,209], num_parallel_workers=2)
|
||||
dataset_zip = ds.zip((dataset1f, dataset2f))
|
||||
dataset_map = dataset_zip.map( input_columns=["col1", "col3"], output_columns=["out1", "out3"] , operations=lambda x1,x3: (x1 + 400, x3+500))
|
||||
ret1 = []
|
||||
ret3 = []
|
||||
for item in dataset_map.create_dict_iterator():
|
||||
ret1.append(item["out1"])
|
||||
ret3.append(item["out3"])
|
||||
assert ret1[0] == 400
|
||||
assert ret1[6] == 408
|
||||
assert ret3[0] == 700
|
||||
assert ret3[6] == 708
|
||||
|
||||
|
||||
def filter_func_Partial(col1, col2):
|
||||
if col1[0] % 3 == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def generator_big(maxid=20):
|
||||
for i in range(maxid):
|
||||
yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
|
||||
|
||||
# test with row_data_buffer > 1
|
||||
def test_filter_by_generator_Partial():
|
||||
ds.config.load('../data/dataset/declient_filter.cfg')
|
||||
dataset = ds.GeneratorDataset(source = generator_mc(99), column_names = ["col1", "col2"])
|
||||
dataset_s = dataset.shuffle(4)
|
||||
dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1)
|
||||
|
||||
for item in dataset_f1.create_dict_iterator():
|
||||
assert item["col1"] % 3 == 0
|
||||
|
||||
|
||||
def filter_func_cifar(col1, col2):
|
||||
if col2 % 3 == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
# test with cifar10
|
||||
def test_filter_case_dataset_cifar10():
|
||||
DATA_DIR_10 = "../data/dataset/testCifar10Data"
|
||||
ds.config.load('../data/dataset/declient_filter.cfg')
|
||||
dataset_c = ds.Cifar10Dataset(dataset_dir = DATA_DIR_10, num_samples = 100000, shuffle=False)
|
||||
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
for item in dataset_f1.create_dict_iterator():
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
assert item["label"] % 3 == 0
|
||||
|
||||
# column id sort
|
||||
|
||||
def generator_sort1(maxid=20):
|
||||
for i in range(maxid):
|
||||
yield (np.array([i]), np.array([i + 100]), np.array([i + 200]))
|
||||
|
||||
def generator_sort2(maxid=20):
|
||||
for i in range(maxid):
|
||||
yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500]))
|
||||
|
||||
|
||||
def filter_func_part_sort(col1, col2, col3, col4, col5, col6):
|
||||
return True
|
||||
|
||||
def filter_func_map_sort(col1, col2, col3):
|
||||
return (col1, col2, col3)
|
||||
|
||||
def test_filter_by_generator_with_map_all_sort():
|
||||
dataset1 = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"])
|
||||
dataset2 = ds.GeneratorDataset(generator_sort2(10), ["col4", "col5", "col6"])
|
||||
|
||||
dataz = ds.zip((dataset1, dataset2))
|
||||
dataset_f = dataz.filter(predicate=filter_func_part_sort, num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
ret_data = []
|
||||
for item in dataset_f.create_dict_iterator():
|
||||
num_iter += 1
|
||||
ret_data.append(item)
|
||||
|
||||
assert num_iter == 10
|
||||
assert ret_data[0]["col1"] == 0
|
||||
assert ret_data[9]["col6"] == 509
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_diff_predicate_func()
|
||||
test_filter_case_dataset_cifar10()
|
||||
test_filter_by_generator_Partial0()
|
||||
test_filter_by_generator_Partial1()
|
||||
test_filter_by_generator_Partial2()
|
||||
test_filter_by_generator_with_batch()
|
||||
test_filter_by_generator_with_batch_after()
|
||||
test_filter_by_generator_with_input_column()
|
||||
test_filter_by_generator_with_map_all_col()
|
||||
test_filter_by_generator_with_map_all_sort()
|
||||
test_filter_by_generator_with_map_part_col()
|
||||
test_filter_by_generator_with_no()
|
||||
test_filter_by_generator_with_rename()
|
||||
test_filter_by_generator_with_repeat()
|
||||
test_filter_by_generator_with_repeat_after()
|
||||
test_filter_by_generator_with_shuffle()
|
||||
test_filter_by_generator_with_shuffle_after()
|
||||
test_filter_by_generator_with_zip()
|
||||
test_filter_by_generator_with_zip_after()
|
||||
test_filter_by_generator_Partial()
|
|
@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
|
|||
|
||||
|
||||
def check(project_columns):
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns)
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False)
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False)
|
||||
|
||||
for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()):
|
||||
assert len(data_actual) == len(data_expected)
|
||||
|
|