!508 [Dataset] Adding sync_wait operator for dataset
Merge pull request !508 from EricZ/master
commit dc0491caf9
@@ -48,6 +48,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
  {kMap, &DEPipeline::ParseMapOp},
  {kFilter, &DEPipeline::ParseFilterOp},
  {kBatch, &DEPipeline::ParseBatchOp},
  {kBarrier, &DEPipeline::ParseBarrierOp},
  {kRepeat, &DEPipeline::ParseRepeatOp},
  {kSkip, &DEPipeline::ParseSkipOp},
  {kZip, &DEPipeline::ParseZipOp},

@@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
  return Status::OK();
}

Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
  // Right now barrier should only take num_rows_per_buffer = 1
  // The reason for this is because having it otherwise can lead to blocking issues
  // See barrier_op.h for more details
  (void)builder->SetRowsPerBuffer(1);
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "condition_name") {
        (void)builder->SetConditionName(ToString(value));
      } else if (key == "condition_func") {
        (void)builder->SetConditionFunc(value.cast<py::function>());
      }
    }
  }

  std::shared_ptr<BarrierOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  int32_t prefetch_size = 0;
  if (args.contains("prefetch_size")) {
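For context, the "condition_name" and "condition_func" keys consumed above are exactly what the Python layer's SyncWaitDataset.get_args() emits later in this change. A minimal, purely illustrative sketch of that payload (the values are made up; in the real pipeline condition_func is BlockReleasePair.block_func):

# Illustrative only: the shape of the args dict that reaches ParseBarrierOp.
def block_until_released():
    # stands in for BlockReleasePair.block_func; must return a bool
    return True

barrier_args = {
    "condition_name": "policy",
    "condition_func": block_until_released,
}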
@@ -40,6 +40,7 @@ enum OpName {
  kShuffle,
  kMindrecord,
  kBatch,
  kBarrier,
  kCache,
  kRepeat,
  kSkip,

@@ -115,6 +116,8 @@ class DEPipeline {

  Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -481,6 +481,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("STORAGE", OpName::kStorage)
    .value("SHUFFLE", OpName::kShuffle)
    .value("BATCH", OpName::kBatch)
    .value("BARRIER", OpName::kBarrier)
    .value("MINDRECORD", OpName::kMindrecord)
    .value("CACHE", OpName::kCache)
    .value("REPEAT", OpName::kRepeat)
@@ -25,6 +25,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
@@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT
    dataset_op.cc
    parallel_op.cc
    pipeline_op.cc
    barrier_op.cc
    batch_op.cc
    device_queue_op.cc
    map_op.cc
@@ -0,0 +1,235 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/barrier_op.h"
#include <utility>
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
BarrierOp::Builder::Builder() {
  // Some arguments to the BarrierOp constructor have a default argument that is taken
  // from the client config.
  // The user may choose to change these values for the construction of the BarrierOp by
  // using the various builder set methods.

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_rows_per_buffer_ = cfg->rows_per_buffer();
  builder_op_connector_size_ = cfg->op_connector_size();
}

Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }

Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
                                     builder_condition_func_);
  return Status::OK();
}

// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
                     py::function condition_func)
    : PipelineOp(op_connector_size),
      rows_per_buffer_(rows_per_buffer),
      buffer_id_(0),
      clean_up_(false),
      eof_(false),
      condition_name_(condition_name),
      condition_function_(condition_func) {}

// destructor
BarrierOp::~BarrierOp() {}

// Entry point for Barrier, called by launch()
Status BarrierOp::operator()() {
  // The children_num_ parameter needs to be put here
  // Synchronize with TaskManager once the thread is created.
  TaskManager::FindMe()->Post();

  // create child iterator, right now this barrier is a pipeline operator
  int32_t worker_id = 0;
  int32_t child_idx = 0;
  child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);

  // Loop until eof is true
  while (!eof_) {
    // Create new table to put the new tensor rows
    std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
    RETURN_IF_NOT_OK(prepare(curr_table.get()));

    // If an eof got picked up during the above prepare, then we're done
    if (eof_) {
      break;
    }

    // we have to output new buffer with possibly different buffer size, possibly one row
    while (!clean_up_) {
      // 1. If a previous loop iteration sent the current table out, then create a new one.

      if (curr_table == nullptr) {
        curr_table = std::make_unique<TensorQTable>();
      }

      // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished
      RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));

      // 3 create and update buffer and send it to the out connector
      if (!curr_table->empty()) {
        std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
        curr_buffer->set_tensor_table(std::move(curr_table));
        curr_buffer->set_column_name_map(col_name_id_map_);
        MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
                      << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
        buffer_id_++;
      }
    }

    // 4 handle drain state.
    if (clean_up_) {
      MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
      // Send the eoe up.
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
    }
  }
  // 5 handle eof
  // propagate eof here.
  MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  return Status::OK();
}

// Handles preprocessing of the main loop, used when starting new epoch
Status BarrierOp::prepare(TensorQTable *const table) {
  MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
  clean_up_ = false;
  buffer_id_ = 0;
  if (table == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
  }
  // fill initial row
  TensorRow new_row = {};
  // use iterator to get next row and invoke pyfunc wait
  RETURN_IF_NOT_OK(getNextTensorRow(&new_row));

  // If the first row fetching resulted in eof, then we are done.
  if (eof_) {
    return Status::OK();
  }
  if (new_row.empty()) {
    // This epoch is empty
    return Status::OK();
  }
  // Pack this first row into our tensor table
  // first row we also have to check if we should block
  RETURN_IF_NOT_OK(blockCond());

  table->push_back(std::move(new_row));
  // At this point we have 1 row produced, we take the old column map id and use it in the new table
  // Initializing col_name_id_map_ from the first data buffer.
  col_name_id_map_ = child_iterator_->col_name_id_map();
  // the update code below shouldn't do anything bad if the column name already exists.
  return Status::OK();
}

// fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) {
  if (table == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
  }
  TensorRow new_row = {};
  while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
    RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
    // Early exit the loop if we got empty row from any of our child iterations
    if (new_row.empty()) {
      return Status::OK();
    }
    // else we got a row so pack it into the tensor table.
    RETURN_IF_NOT_OK(blockCond());

    table->push_back(std::move(new_row));
  }
  return Status::OK();
}

// function executes a py_func and blocks until condition becomes true.
Status BarrierOp::blockCond() {
  {
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    // we have condition name, however the flexibility is in python today
    try {
      // Invoke python function
      py::object ret_py_obj = condition_function_();
      // Process the return value
      if (!py::isinstance<py::bool_>(ret_py_obj)) {
        return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
      }
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }
  return Status::OK();
}

// fetches next Barrier buffer row
Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
  // iterate over all iterators and generate a row
  RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row));
  // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
  if (new_row->empty()) {
    // If we did not get a row from any of the children, then it's the end of an epoch and we can move
    // to drain state.
    MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
    clean_up_ = true;
    // If we picked up an eof here, then we are completely done.
    if ((child_iterator_)->eof_handled()) {
      MS_LOG(INFO) << "Barrier operator iterator got EOF.";
      eof_ = true;
    }
    return Status::OK();
  }
  return Status::OK();
}

// A function that prints info about the Operator
void BarrierOp::Print(std::ostream &out, bool show_all) const {
  // Call base class printer first
  PipelineOp::Print(out, show_all);
  out << "\nBarrierOp:\n"
      << "\nCondition " << condition_name_ << "\n\n";
}

// overwrite function and handle eof
Status BarrierOp::EofReceived(int32_t) {
  MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
  return Status::OK();
}

// overwrite function and handle eoe
Status BarrierOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
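To make the contract in blockCond() concrete: the registered condition function is called once per row (with rows_per_buffer pinned to 1 by ParseBarrierOp), must block until the consumer releases the pipeline, and must return a bool. A minimal, hypothetical sketch of such a function, independent of the BlockReleasePair class added later in this change:

import threading

release = threading.Event()

def condition_func():
    # Block the barrier until someone calls release.set(); blockCond() rejects
    # any non-bool return value, so return True once released.
    release.wait()
    return True

# Somewhere in the consuming loop, unblock the next row:
# release.set()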
@@ -0,0 +1,172 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {
// Forward declare
class DataBuffer;
class ExecutionTree;

// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has
// been received. This signal is given from python layer. The current barrier design respects the
// rows per buffer design and will only output a buffer with rows once it has received rows per buffer
// signals from python.

class BarrierOp : public PipelineOp {
 public:
  // The nested builder class inside of the BarrierOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.

  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @param int32_t op_connector_size
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @param const std::string & condition_name
    // @return Builder setter method returns reference to the builder.
    Builder &SetConditionName(const std::string &condition_name) {
      builder_condition_name_ = condition_name;
      return *this;
    }

    // Setter method.
    // @param py::function condition_func - blocking condition function
    // @return Builder setter method returns reference to the builder.
    Builder &SetConditionFunc(py::function condition_func) {
      builder_condition_func_ = condition_func;
      return *this;
    }

    // The builder "build" method creates the BarrierOp dataset Operator.
    // @return shared_ptr to the new BarrierOp object
    Status Build(std::shared_ptr<BarrierOp> *);

   private:
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
    std::string builder_condition_name_;
    py::function builder_condition_func_;

    Status SanityCheck() const;
  };

  // Constructor for BarrierOp
  // @param rows_per_buffer - number of rows in output buffer
  // @param op_connector_size - connector size
  // @param condition_name - the condition name associated with this operator
  // @param condition_func - the blocking condition check per row
  // @note - currently rows_per_buffer should = 1 for barrier.
  // The reason for this is having other values would complicate how the pipeline behaves with other operators
  // One example of such case is having batch after barrier. Batch would be waiting for data and having
  // rows per buffer in this case can result in hanging
  BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
            py::function condition_func);

  // Destructor
  ~BarrierOp();

  Status EofReceived(int32_t) override;

  Status EoeReceived(int32_t) override;

  // Print function for Barrier
  // @param out - output stream to print to
  // @param show_all - if it should print everything
  void Print(std::ostream &out, bool show_all) const override;

  // Provide stream operator for displaying it
  friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
    bo.Print(out, false);
    return out;
  }

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code return
  Status operator()() override;

  // Handles preprocessing of the main loop, used when starting new epoch
  // @param table - a table of tensors to be moved into a buffer
  Status prepare(TensorQTable *const table);

  // This function takes a table and repeatedly adds rows to it.
  // @param table - a table of tensors to be moved into a buffer
  Status fillBuffer(TensorQTable *const table);

  // Gets next tensor row and sets control signals
  Status getNextTensorRow(TensorRow *new_row);

  // This function runs the wait function on condition
  Status blockCond();

 private:
  // clean up variable to return incomplete buffer
  bool clean_up_;
  // end of file state, we stop reading data and shut down
  bool eof_;
  // rows per buffer
  int32_t rows_per_buffer_;
  // buffer_id
  int32_t buffer_id_;
  // local variable to keep track of the buffer information
  std::unordered_map<std::string, int32_t> col_name_id_map_;
  // iterator to pull new rows, we only have one child
  std::unique_ptr<ChildIterator> child_iterator_;
  // condition name, to support multiple barriers
  std::string condition_name_;
  // Function pointer of blocking function
  py::function condition_function_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
@@ -34,7 +34,7 @@ class DataBuffer;

class ZipOp : public PipelineOp {
 public:
  // The nested builder class inside of the BatchOp is used to help manage all of
  // The nested builder class inside of the ZipOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.

@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp {
  };

  // Constructor for ZipOp
  // @param rows_per_buffer number of rows in output buffer
  // @param op_connector_size connector
  // @param rows_per_buffer - number of rows in output buffer
  // @param op_connector_size - connector size
  ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);

  // Destructor

@@ -88,8 +88,8 @@ class ZipOp : public PipelineOp {
  Status EoeReceived(int32_t) override;

  // Print function for Zip
  // @param out output stream to print to
  // @param show_all if it should print everything
  // @param out - output stream to print to
  // @param show_all - if it should print everything
  void Print(std::ostream &out, bool show_all) const override;

  // Provide stream operator for displaying it

@@ -113,14 +113,14 @@ class ZipOp : public PipelineOp {
  Status fillBuffer(TensorQTable *const table);

  // Special handle case where an empty row has been received from child iterator
  // @note we need to drain eoe signals from all children connectors.
  // @details when this function is called, then we encountered eoe at child iterator
  // @note - we need to drain eoe signals from all children connectors.
  // @details - when this function is called, then we encountered eoe at child iterator
  // we have to drain rows from other child iterators until we hit eoe from all other child iterators
  Status drainPipeline();

  // Merges 1 row from each childIterator together
  // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
  // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
  // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
  // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
  // @details merge rows from iterator together. This is the main functionality for ZipOp
  // this function takes one row and fills it with tensors from rows fetched
  // from childIterators.
@@ -28,6 +28,7 @@ import multiprocessing
import queue
from enum import Enum
from importlib import import_module
import threading

import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \

@@ -40,7 +41,7 @@ from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
    check_zip_dataset, check_add_column, check_textfiledataset
    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:

@@ -141,6 +142,7 @@ class Dataset:
        self._batch_size = None
        self._num_classes = None
        self._repeat_count = None
        self._sync = False

    def get_args(self):
        """
@@ -198,6 +200,30 @@ class Dataset:
        """
        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)

    @check_sync_wait
    def sync_wait(self, condition_name, num_batch=1, callback=None):
        '''
        Add a blocking condition to the input Dataset.

        Args:
            input_dataset (Dataset): Input dataset to apply flow control.
            num_batch (int): the number of batches without blocking at the start of each epoch.
            condition_name (str): The condition name that is used to toggle sending next row.
            callback (function): The callback function that will be invoked when sync_update is called.

        Raises:
            RuntimeError: If condition name already exists.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> data = data.sync_wait("callback1")
            >>> data = data.batch(batch_size)
            >>> for batch_data in data.create_dict_iterator():
            >>>     data = data.sync_update("callback1")
        '''
        return SyncWaitDataset(self, condition_name, num_batch, callback)

    @check_shuffle
    def shuffle(self, buffer_size):
        """
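A slightly fuller usage sketch than the docstring example, modeled on the test added at the end of this change (the generator, column name, and condition name are illustrative, not part of the API):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(100):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, column_names=["input"])
data = data.sync_wait(condition_name="policy")  # barrier inserted before batch
data = data.batch(4)

for batch in data.create_dict_iterator():
    # ... consume `batch`, e.g. compute a new augmentation policy ...
    data.sync_update(condition_name="policy")   # release the next batch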
@@ -220,6 +246,9 @@ class Dataset:
        Returns:
            ShuffleDataset, dataset shuffled.

        Raises:
            RuntimeError: If sync operators exist before shuffle.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object

@@ -821,6 +850,9 @@ class Dataset:
        self._input_indexs = value

    def _get_pipeline_info(self):
        """
        Gets pipeline information.
        """
        device_iter = TupleIterator(self)
        self._output_shapes = device_iter.get_output_shapes()
        self._output_types = device_iter.get_output_types()
@@ -875,6 +907,30 @@ class Dataset:
            return self.input[0].num_classes()
        return None

    def get_sync_notifiers(self):
        if self.input:
            return self.input[0].get_sync_notifiers()
        return {}

    def is_sync(self):
        if self.input:
            return self.input[0].is_sync()
        return False

    def sync_update(self, condition_name, num_batch=None, data=None):
        """
        Release a blocking condition and trigger the callback with the given data.

        Args:
            condition_name (str): The condition name that is used to toggle sending next row.
            num_batch (int or None): The number of batches (rows) that are released.
                When num_batch is None, it will default to the number specified by the sync_wait operator.
            data (dict or None): The data passed to the callback.
        """
        notifiers_dict = self.get_sync_notifiers()
        if condition_name not in notifiers_dict:
            raise RuntimeError("Condition name not found")
        if num_batch is not None:
            num_batch *= self.get_batch_size()
        notifiers_dict[condition_name](num_batch, data)

    def get_batch_size(self):
        """
        Get the size of a batch.
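Note that sync_update counts in batches while the underlying release counter works in rows; the conversion is just a multiplication by the batch size that BatchDataset propagates down to the sync point. A small illustrative calculation (values made up):

# Illustrative only: sync_wait(...) followed by batch(4).
batch_size = 4          # propagated to the sync point by BatchDataset
num_batch = 2           # caller asks sync_update to release two batches
rows_released = num_batch * batch_size
print(rows_released)    # 8 rows are subtracted from the blocking counter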
@@ -978,6 +1034,8 @@ class BatchDataset(DatasetOp):
        if BatchDataset._is_ancestor_of_repeat(input_dataset):
            logger.warning("Repeat is located before batch, data from two epochs can be batched together.")

        BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)

        self.batch_size = batch_size
        self.drop_remainder = drop_remainder
        self.per_batch_map = per_batch_map

@@ -1034,6 +1092,20 @@ class BatchDataset(DatasetOp):
            flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
        return flag

    @staticmethod
    def _update_batch_size_for_syncwait(dataset, batch_size):
        """
        Utility function to notify batch size to sync_wait.

        Args:
            dataset (Dataset): dataset to be checked.
            batch_size (int): batch size to notify.
        """
        if isinstance(dataset, SyncWaitDataset):
            dataset.update_sync_batch_size(batch_size)
        for input_dataset in dataset.input:
            BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)


class BatchInfo(CBatchInfo):
    """
@@ -1058,6 +1130,108 @@ class BatchInfo(CBatchInfo):
        """
        return


class BlockReleasePair:
    """
    The blocking condition class used by SyncWaitDataset.

    Args:
        init_release_rows (int): Number of rows to allow through the pipeline.
        callback (function): The callback function that will be called when release is called.
    """
    def __init__(self, init_release_rows, callback=None):
        self.row_count = -init_release_rows
        self.cv = threading.Condition()
        self.callback = callback
        self.default_rows = init_release_rows

    def __deepcopy__(self, memodict):
        if id(self) in memodict:
            return memodict[id(self)]
        memodict[id(self)] = self
        # condition variable and callback are the same, but reset the counter
        self.reset()
        return self

    def reset(self):
        with self.cv:
            self.row_count = -self.default_rows
            self.cv.notify_all()

    def update_batched_size(self, batch_size):
        # should only be used before the pipeline is created
        self.row_count *= batch_size
        self.default_rows *= batch_size

    def block_func(self):
        with self.cv:
            self.cv.wait_for(lambda: self.row_count < 0)
            self.row_count += 1
            return True

    def release_func(self, pass_rows=None, data=None):
        with self.cv:
            if pass_rows is None:
                pass_rows = self.default_rows
            self.row_count -= pass_rows
            if self.callback is not None:
                self.callback(data)
            self.cv.notify_all()


class SyncWaitDataset(DatasetOp):
    """
    The result of adding a blocking condition to the input Dataset.

    Args:
        input_dataset (Dataset): Input dataset to apply flow control.
        num_batch (int): the number of batches without blocking at the start of each epoch.
        condition_name (str): The condition name that is used to toggle sending next row.
        callback (function): The callback function that will be invoked when sync_update is called.

    Raises:
        RuntimeError: If condition name already exists.
    """

    def __init__(self, input_dataset, condition_name, num_batch, callback=None):
        super().__init__()
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        # set to the default value, waiting for the batch to update it
        self._condition_name = condition_name
        self._pair = BlockReleasePair(num_batch, callback)
        if self._condition_name in self.input[0].get_sync_notifiers():
            raise RuntimeError("Condition name is already in use")

    def get_sync_notifiers(self):
        return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}

    def is_sync(self):
        return True

    def get_args(self):
        args = super().get_args()
        args["condition_name"] = self._condition_name
        args["condition_func"] = self._pair.block_func
        return args

    def update_sync_batch_size(self, batch_size):
        self._pair.update_batched_size(batch_size)

    @staticmethod
    def _is_ancestor_of_batch(dataset):
        """
        Utility function to find the case where sync_wait is used before batch.

        Args:
            dataset (Dataset): dataset to be checked.
        Return:
            True or False
        """
        if isinstance(dataset, BatchDataset):
            return True
        flag = False
        for input_dataset in dataset.input:
            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
        return flag


class ShuffleDataset(DatasetOp):
    """
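The counter semantics above are easy to miss: row_count starts at minus the number of pre-released rows, block_func admits one row whenever the count is negative, and release_func pushes the count further negative. A standalone sketch (plain threading, not the MindSpore class) that exercises the same idea:

import threading

class CounterPair:
    """Toy version of BlockReleasePair's counter (illustrative only)."""
    def __init__(self, init_release_rows):
        self.row_count = -init_release_rows
        self.cv = threading.Condition()

    def block(self):
        with self.cv:
            self.cv.wait_for(lambda: self.row_count < 0)
            self.row_count += 1
            return True

    def release(self, rows):
        with self.cv:
            self.row_count -= rows
            self.cv.notify_all()

pair = CounterPair(init_release_rows=2)
assert pair.block() and pair.block()  # two rows pass without waiting
pair.release(1)                       # allow one more row through
assert pair.block()                   # third row now passes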
@@ -1066,6 +1240,9 @@ class ShuffleDataset(DatasetOp):
    Args:
        input_dataset (Dataset): Input Dataset to be shuffled.
        buffer_size (int): The size of the buffer.

    Raises:
        RuntimeError: If sync operators exist before shuffle.
    """

    def __init__(self, input_dataset, buffer_size):

@@ -1074,6 +1251,8 @@ class ShuffleDataset(DatasetOp):
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs
        if self.is_sync():
            raise RuntimeError("No shuffle after sync operators")

    def get_args(self):
        args = super().get_args()
@@ -1427,6 +1606,9 @@ class ZipDataset(DatasetOp):
        """
        return None

    def is_sync(self):
        return any([c.is_sync() for c in self.input])

    def get_args(self):
        args = super().get_args()
        return args
@@ -129,6 +129,8 @@ class Iterator:
            op_type = OpName.MINDRECORD
        elif isinstance(dataset, de.BatchDataset):
            op_type = OpName.BATCH
        elif isinstance(dataset, de.SyncWaitDataset):
            op_type = OpName.BARRIER
        elif isinstance(dataset, de.ZipDataset):
            op_type = OpName.ZIP
        elif isinstance(dataset, de.MapDataset):
@@ -652,6 +652,22 @@ def check_batch(method):

    return new_method


def check_sync_wait(method):
    """check the input arguments of sync_wait."""
    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        nreq_param_str = ['condition_name']
        nreq_param_int = ['step_size']

        check_param_type(nreq_param_int, param_dict, int)

        check_param_type(nreq_param_str, param_dict, str)

        return method(*args, **kwargs)

    return new_method


def check_shuffle(method):
    """check the input arguments of shuffle."""
@@ -12,8 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
"""
Testing configuration manager
"""
import filecmp
import glob
import os

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

def test_basic():
    ds.config.load('../data/dataset/declient.cfg')

@@ -36,6 +46,34 @@ def test_basic():
    assert ds.config.get_prefetch_size() == 4
    assert ds.config.get_seed() == 5

def test_pipeline():
    """
    Test that our configuration pipeline works when we set parameters at dataset level
    """
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_num_parallel_workers(2)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    ds.serialize(data1, "testpipeline.json")

    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    ds.config.set_num_parallel_workers(4)
    data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
    ds.serialize(data2, "testpipeline2.json")

    # check that the generated output is different
    assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))

    # this test passes currently because our num_parallel_workers don't get updated.

    # remove generated json files
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


if __name__ == '__main__':
    test_basic()
    test_pipeline()
@@ -0,0 +1,182 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import mindspore.dataset as ds
from mindspore import log as logger
import time
import numpy as np


def gen():
    for i in range(100):
        yield np.array(i),


class Augment:
    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input):
        return input

    def update(self, data):
        self.loss = data["loss"]


def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        assert (data["input"][0] == count)
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: test shuffle before sync
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        # time.sleep(0.5)
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync operators
    """
    logger.info("test_two_sync")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # notice that with our design, we need to have step_size = shuffle size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")

    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
    """
    Test sync wait with epochs: test sync with epochs in dataset pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    for epochs in range(3):
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator():
            assert (data["input"][0] == count)
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.shuffle(shuffle_size)
    except BaseException as e:
        assert "shuffle" in str(e)
    dataset = dataset.batch(batch_size)


def test_sync_exception_02():
    """
    Test sync: with duplicated condition name
    """
    logger.info("test_sync_exception_02")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # notice that with our design, we need to have step_size = shuffle size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
    except BaseException as e:
        assert "name" in str(e)
    dataset = dataset.batch(batch_size)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_epoch()