!3212 GetDatasize feature

Merge pull request !3212 from anzhengqi/epochs-ready
This commit is contained in:
mindspore-ci-bot 2020-07-20 14:18:27 +08:00 committed by Gitee
commit 8e4c0a9d93
94 changed files with 5260 additions and 397 deletions

View File

@ -25,6 +25,8 @@
#include "minddata/dataset/engine/dataset_iterator.h" #include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" #include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/filter_op.h" #include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h" #include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h" #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
@ -84,7 +86,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kRandomData, &DEPipeline::ParseRandomDataOp}, {kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp}, {kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp}, {kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}}; {kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) { DEPipeline::DEPipeline() : iterator_(nullptr) {
try { try {
@ -166,8 +169,8 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
// Function to launch the tree execution. // Function to launch the tree execution.
Status DEPipeline::LaunchTreeExec() { Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
RETURN_IF_NOT_OK(tree_->Prepare()); RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
RETURN_IF_NOT_OK(tree_->Launch()); RETURN_IF_NOT_OK(tree_->Launch());
iterator_ = std::make_unique<DatasetIterator>(tree_); iterator_ = std::make_unique<DatasetIterator>(tree_);
if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
@ -252,6 +255,16 @@ int DEPipeline::GetRepeatCount() const { return repeat_num_; }
float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); } float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
Status DEPipeline::StopSend() {
// tree_.root() must be DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
if (op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp");
}
op->StopSend();
return Status::OK();
}
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); } int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); } bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
@ -804,6 +817,18 @@ Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK(); return Status::OK();
} }
Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<EpochCtrlOp> op;
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op));
*top = op;
return Status::OK();
}
Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) { std::shared_ptr<DatasetOp> *bottom) {
std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>(); std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
@ -973,8 +998,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetDeviceType(ToString(value)); (void)builder->SetDeviceType(ToString(value));
} else if (key == "device_id") { } else if (key == "device_id") {
(void)builder->SetDeviceId(ToInt(value)); (void)builder->SetDeviceId(ToInt(value));
} else if (key == "num_batch") { } else if (key == "send_epoch_end") {
(void)builder->SetNumBatch(ToInt(value)); (void)builder->SetSendEpochEnd(ToBool(value));
} }
} }
} }

View File

@ -70,7 +70,8 @@ enum OpName {
kRandomData, kRandomData,
kTextFile, kTextFile,
kBuildVocab, kBuildVocab,
kClue kClue,
kEpochCtrl
}; };
// The C++ binder class that we expose to the python script. // The C++ binder class that we expose to the python script.
@ -90,7 +91,7 @@ class DEPipeline {
Status AssignRootNode(const DsOpPtr &dataset_op); Status AssignRootNode(const DsOpPtr &dataset_op);
// Function to launch the tree execution. // Function to launch the tree execution.
Status LaunchTreeExec(); Status LaunchTreeExec(int32_t num_epochs);
// Get a row of data as dictionary of column name to the value. // Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output); Status GetNextAsMap(py::dict *output);
@ -143,6 +144,10 @@ class DEPipeline {
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom); std::shared_ptr<DatasetOp> *bottom);
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
@ -189,6 +194,8 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private: private:

View File

@ -159,7 +159,7 @@ void bindDEPipeline(py::module *m) {
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters", .def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
.def("GetNextAsMap", .def("GetNextAsMap",
[](DEPipeline &de) { [](DEPipeline &de) {
py::dict out; py::dict out;
@ -188,6 +188,7 @@ void bindDEPipeline(py::module *m) {
.def("GetBatchSize", &DEPipeline::GetBatchSize) .def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses) .def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount) .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) { .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true; return true;
@ -999,7 +1000,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("BUILDVOCAB", OpName::kBuildVocab) .value("BUILDVOCAB", OpName::kBuildVocab)
.value("CELEBA", OpName::kCelebA) .value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile) .value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue); .value("CLUE", OpName::kClue)
.value("EPOCHCTRL", OpName::kEpochCtrl);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic()) (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix) .value("DE_JIEBA_MIX", JiebaMode::kMix)

View File

@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
out_map->clear(); out_map->clear();
TensorRow curr_row; TensorRow curr_row;
MS_LOG(INFO) << "get next as map start.";
RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
MS_LOG(INFO) << "fetchNextTensor success.";
// Return empty map if there's no data // Return empty map if there's no data
if (curr_row.empty()) { if (curr_row.empty()) {
@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again. // want to iterate again.
if (eof_handled_) { if (eof_handled_) {
return Status::OK(); std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
} }
// Check if we need to get a new DataBuffer to iterate. // Check if we need to get a new DataBuffer to iterate.
@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
// handle eoe and eof messages here. // handle eoe and eof messages here.
// //
// An eoe buffer means we have iterated fully to the end of the tree. // An eoe buffer means we have iterated an epoch.
// An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of // The next buffer in the pipeline might be an EOF or a databuffer for next epoch
// all operators.
if (curr_buffer_->eoe()) { if (curr_buffer_->eoe()) {
MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; MS_LOG(INFO) << "End of data iteration.";
curr_buffer_.reset(); // explicitly free the eoe buffer
// Before returning the last empty vector, fetch the eof buffer which should be the last
// buffer, and then free it.
RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
if (!curr_buffer_->eof()) {
RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
}
eof_handled_ = true;
curr_buffer_.reset(); // explicitly free the eof buffer
// Set tree to Finished state
root_->Tree()->SetFinished();
return Status::OK(); return Status::OK();
} }
// An eof buffer means it is the end of execution and all operators are shutting down.
// Because there is no more data to return to the caller, this will change `eof_handled_` state and
// returns status unexpected error.
if (curr_buffer_->eof()) { if (curr_buffer_->eof()) {
// An eof by itself, without being preceded by an eoe, is possible if a repeat operator
// exists below us in the stack. Repeat operator eats eoe's but eventually allows the
// flow of an eof up the pipeline by itself.
eof_handled_ = true; eof_handled_ = true;
curr_buffer_.reset(); // explicitly free the eof buffer curr_buffer_.reset(); // explicitly free the eof buffer
// Set tree to Finished state std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
root_->Tree()->SetFinished(); RETURN_STATUS_UNEXPECTED(err);
return Status::OK();
} }
} }
@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again. // want to iterate again.
if (eof_handled_) { if (eof_handled_) {
return Status::OK(); std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
} }
// Check if we need to get a new DataBuffer to iterate. // Check if we need to get a new DataBuffer to iterate.
if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
// GetNextInput() depends on current_op's EoeReceived. So, EOE buffer might be already be handled and
// this child iterator might not see EOE buffer.
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
// Unlike the DatasetIterator, this child iterator does not quit after eoe. // If an eoe is picked up here, we simply return an empty vector and it's up to the
// Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
// caller to decide what it wants to do next. // caller to decide what it wants to do next.
if (curr_buffer_->eoe()) { if (curr_buffer_->eoe()) {
MS_LOG(DEBUG) << "Child iterator picked up EOE."; MS_LOG(DEBUG) << "Child iterator picked up EOE.";
end_epoch_ = true; end_epoch_ = true;
return Status::OK(); return Status::OK();
} else {
end_epoch_ = false;
} }
if (curr_buffer_->eof()) { if (curr_buffer_->eof()) {

View File

@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase {
// @return The string to column id mapping. // @return The string to column id mapping.
std::unordered_map<std::string, int32_t> GetColumnNameMap() const override; std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
// Return T/F if end of epoch
bool end_of_epoch() { return end_epoch_; }
private: private:
DatasetOp *current_op_; // The parent operator. We consume from it's children. DatasetOp *current_op_; // The parent operator. We consume from it's children.
int32_t child_idx_; // The specific child this iterator will fetch from. int32_t child_idx_; // The specific child this iterator will fetch from.

View File

@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
shuffle_op.cc shuffle_op.cc
zip_op.cc zip_op.cc
concat_op.cc concat_op.cc
epoch_ctrl_op.cc
cache_base_op.cc cache_base_op.cc
cache_lookup_op.cc cache_lookup_op.cc
cache_op.cc cache_op.cc

View File

@ -17,11 +17,13 @@
#include "minddata/dataset/engine/datasetops/build_vocab_op.h" #include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include <algorithm> #include <algorithm>
#include <iomanip>
#include <limits> #include <limits>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -202,5 +204,29 @@ BuildVocabOp::Builder::Builder()
builder_num_workers_ = cfg->num_parallel_workers(); builder_num_workers_ = cfg->num_parallel_workers();
builder_connector_size_ = cfg->op_connector_size(); builder_connector_size_ = cfg->op_connector_size();
} }
// A print method typically used for debugging
void BuildVocabOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <BuildVocabOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCode is needed here to show more info about the op."
<< "\n\n";
}
}
// Pre-Visitor accept method for NodePass
Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<BuildVocabOp>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -131,6 +131,21 @@ class BuildVocabOp : public ParallelOp {
~BuildVocabOp() = default; ~BuildVocabOp() = default;
/// \brief A print method typically used for debugging
/// \param[out] out The output stream to write output to
/// \param[in] show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \briefStream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param[out] out Reference to the output stream being overloaded
/// \param[in] vop - reference to the BuildVocabOp to display
/// \return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) {
vop.Print(out, false);
return out;
}
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;
// collect the work product from each worker // collect the work product from each worker
@ -152,6 +167,12 @@ class BuildVocabOp : public ParallelOp {
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
private: private:
const int32_t interval_; const int32_t interval_;
bool special_first_; bool special_first_;

View File

@ -96,7 +96,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
} }
} }
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); RETURN_IF_NOT_OK(EofReceived(worker_id));
return Status::OK(); return Status::OK();
} }
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
@ -298,5 +298,19 @@ Status CacheMergeOp::EoeReceived(int32_t worker_id) {
} }
return Status::OK(); return Status::OK();
} }
// Base-class override for handling cases when an eof is received.
Status CacheMergeOp::EofReceived(int32_t worker_id) {
// If we are not in a repeated path, then the merge op gets a eof by itself, without first
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
MS_LOG(DEBUG) << "Cache merge sending eoe";
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
}
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -176,6 +176,11 @@ class CacheMergeOp : public ParallelOp {
/// \return Status object /// \return Status object
Status EoeReceived(int32_t worker_id) override; Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
Status EofReceived(int32_t worker_id) override;
protected: protected:
Status ComputeColMap() override; Status ComputeColMap() override;

View File

@ -26,6 +26,7 @@
#include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
@ -102,6 +103,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
} }
return Status::OK(); return Status::OK();
} }
// Removes child operator in this operator.
Status DatasetOp::RemoveChildren() {
for (const auto &child : child_) {
child->RemoveParent(this);
}
child_.clear();
return Status::OK();
}
// Adds a parent operator to this operator // Adds a parent operator to this operator
void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); } void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
@ -185,6 +195,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
} }
} }
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
// Getter function to get all of our parents.
std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
// Creates the connector within this operator // Creates the connector within this operator
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) { void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers

View File

@ -76,6 +76,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status eerror code returned /// \return Status eerror code returned
Status Remove(); Status Remove();
// Removes child operator in this operator.
Status RemoveChildren();
/// \brief Getter function to get a shared pointer to our child /// \brief Getter function to get a shared pointer to our child
/// \param[in] child_index An operator can have n children. Indicates which child to return. /// \param[in] child_index An operator can have n children. Indicates which child to return.
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
@ -86,6 +89,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return. /// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
void Parent(DatasetOp **parent, int32_t parent_index) const; void Parent(DatasetOp **parent, int32_t parent_index) const;
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> children() const;
// Getter function to get all of our parents.
std::vector<DatasetOp *> parents() const;
// Inserts a operator as the parent current op. // Inserts a operator as the parent current op.
// Inserted op will become the sole parent of the current op. // Inserted op will become the sole parent of the current op.
// The existing parent of the current op will be transferred to the inserted op. // The existing parent of the current op will be transferred to the inserted op.

View File

@ -25,19 +25,21 @@
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/device_queue_tracing.h" #include "minddata/dataset/engine/perf/device_queue_tracing.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h" #include "minddata/dataset/util/task_manager.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
int32_t op_connector_size, int64_t num_batch) int32_t op_connector_size, bool send_epoch_end)
: PipelineOp(op_connector_size), : PipelineOp(op_connector_size),
channel_name_(channel_name), channel_name_(channel_name),
device_type_(device_type), device_type_(device_type),
device_id_(device_id), device_id_(device_id),
prefetch_size_(prefetch_size), prefetch_size_(prefetch_size),
num_batch_(num_batch) {} send_epoch_end_(send_epoch_end),
stop_send_(false) {}
DeviceQueueOp::~DeviceQueueOp() {} DeviceQueueOp::~DeviceQueueOp() {}
@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size)
: builder_prefetch_size_(prefetch_size), : builder_prefetch_size_(prefetch_size),
builder_device_id_(0), builder_device_id_(0),
builder_device_type_(DeviceType::CPU), builder_device_type_(DeviceType::CPU),
builder_channel_name_(""), builder_channel_name_("") {
builder_num_batch_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size(); builder_op_connector_size_ = cfg->op_connector_size();
} }
@ -64,6 +65,18 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) {
return Status::OK(); return Status::OK();
} }
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
// this method checks if the buffer meets the conditions to be sent to TDT
if (buffer->NumRows() != 0) {
TensorRow row;
buffer->GetRow(0, &row);
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
}
}
return Status::OK();
}
Status DeviceQueueOp::operator()() { Status DeviceQueueOp::operator()() {
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
@ -82,23 +95,10 @@ Status DeviceQueueOp::operator()() {
return Status::OK(); return Status::OK();
} }
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
// this method checks if the buffer meets the conditions to be sent to TDT
if (buffer->NumRows() != 0) {
TensorRow row;
buffer->GetRow(0, &row);
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
}
}
return Status::OK();
}
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
Status DeviceQueueOp::SendDataToAscend() { Status DeviceQueueOp::SendDataToAscend() {
MS_LOG(INFO) << "Device queue, sending data to Ascend."; MS_LOG(INFO) << "Device queue, sending data to Ascend.";
int64_t total_batch = 0; int64_t total_batch = 0;
bool is_break_loop = false;
double batch_start_time, end_time; double batch_start_time, end_time;
int32_t batch_cost, tdt_cost; int32_t batch_cost, tdt_cost;
int32_t connector_size = 0; int32_t connector_size = 0;
@ -115,15 +115,20 @@ Status DeviceQueueOp::SendDataToAscend() {
std::unique_ptr<DataBuffer> current_buffer; std::unique_ptr<DataBuffer> current_buffer;
RETURN_IF_NOT_OK(GetNextInput(&current_buffer)); RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
while (!current_buffer->eof() && !is_break_loop) { while (!current_buffer->eof()) {
while (!current_buffer->eoe() && !is_break_loop) { while (!current_buffer->eoe()) {
RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); RETURN_IF_NOT_OK(CheckExceptions(current_buffer));
TensorRow currRow; TensorRow currRow;
for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status == TdtStatus::FAILED) { if (status == TdtStatus::FAILED) {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
} }
if (isProfilingEnable) { if (isProfilingEnable) {
@ -140,9 +145,6 @@ Status DeviceQueueOp::SendDataToAscend() {
profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size);
} }
total_batch++; total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
is_break_loop = true;
}
} }
if (isProfilingEnable) { if (isProfilingEnable) {
connector_size = ChildOpConnectorSize(); connector_size = ChildOpConnectorSize();
@ -150,6 +152,19 @@ Status DeviceQueueOp::SendDataToAscend() {
} }
RETURN_IF_NOT_OK(GetNextInput(&current_buffer)); RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
} }
if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow;
auto status =
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status == TdtStatus::FAILED) {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
}
if (isProfilingEnable) { if (isProfilingEnable) {
connector_size = ChildOpConnectorSize(); connector_size = ChildOpConnectorSize();
connector_capacity = ChildOpConnectorCapacity(); connector_capacity = ChildOpConnectorCapacity();
@ -158,7 +173,7 @@ Status DeviceQueueOp::SendDataToAscend() {
} }
tree_->SetFinished(); tree_->SetFinished();
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; MS_LOG(INFO) << "Device queue total batch is " << total_batch;
return Status::OK(); return Status::OK();
} }
@ -196,9 +211,6 @@ Status DeviceQueueOp::SendDataToGPU() {
} }
RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
total_batch++; total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
is_break_loop = true;
}
} }
if (!TaskManager::FindMe()->Interrupted()) if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer)); RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
@ -211,12 +223,10 @@ Status DeviceQueueOp::SendDataToGPU() {
is_break_loop = true; is_break_loop = true;
} }
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
GpuBufferMgr::GetInstance().Close(handle); GpuBufferMgr::GetInstance().Close(handle);
GpuBufferMgr::GetInstance().CloseConfirm(); GpuBufferMgr::GetInstance().CloseConfirm();
return Status::OK(); return Status::OK();
} }
@ -240,8 +250,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
if (ret == BlockQueueStatus_T::ERROR_INPUT) { if (ret == BlockQueueStatus_T::ERROR_INPUT) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
} else { } else {
MS_LOG(WARNING) << "Retry pushing data..."; if (!stop_send_) {
continue; MS_LOG(WARNING) << "Retry pushing data...";
continue;
}
break;
} }
} else { } else {
break; break;
@ -283,13 +296,11 @@ Status DeviceQueueOp::SendDataToCPU() {
MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << ".";
MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << ".";
total_batch++; total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) { if (stop_send_) break;
break;
}
} }
} }
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
return Status::OK(); return Status::OK();
} }

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "minddata/dataset/engine/datasetops/pipeline_op.h" #include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp {
return *this; return *this;
} }
Builder &SetNumBatch(int64_t num_batch) { Builder &SetSendEpochEnd(bool send_epoch_end) {
builder_num_batch_ = num_batch; builder_send_epoch_end_ = send_epoch_end;
return *this; return *this;
} }
@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp {
// to call this Build() method. It will instantiate the DeviceQueueOp // to call this Build() method. It will instantiate the DeviceQueueOp
// and return it to caller as a shared pointer. // and return it to caller as a shared pointer.
Status Build(std::shared_ptr<DeviceQueueOp> *ptr) { Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
*ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_, *ptr =
builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
builder_prefetch_size_, builder_op_connector_size_, builder_send_epoch_end_);
return Status::OK(); return Status::OK();
} }
@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp {
int32_t builder_device_id_; int32_t builder_device_id_;
DeviceType builder_device_type_; DeviceType builder_device_type_;
std::string builder_channel_name_; std::string builder_channel_name_;
int64_t builder_num_batch_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
bool builder_send_epoch_end_;
}; };
// Name: constructor // Name: constructor
// Description // Description
DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
int32_t op_connector_size, int64_t num_batch); int32_t op_connector_size, bool send_epoch_end);
// Name: destructor // Name: destructor
// Description // Description
@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp {
const int32_t get_prefetch_size() { return prefetch_size_; } const int32_t get_prefetch_size() { return prefetch_size_; }
void StopSend() { stop_send_ = true; }
// Name: Print() // Name: Print()
// Description: A function that prints info about the node // Description: A function that prints info about the node
void Print(std::ostream &out, // In: The output stream to print to void Print(std::ostream &out, // In: The output stream to print to
@ -149,6 +153,7 @@ class DeviceQueueOp : public PipelineOp {
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const; Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const;
private:
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
Status SendDataToAscend(); Status SendDataToAscend();
#endif #endif
@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp {
DeviceType device_type_; DeviceType device_type_;
const int32_t device_id_; const int32_t device_id_;
const int32_t prefetch_size_; const int32_t prefetch_size_;
const int64_t num_batch_; const bool send_epoch_end_;
bool stop_send_;
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
std::shared_ptr<TdtPlugin> tdtInstancePtr; std::shared_ptr<TdtPlugin> tdtInstancePtr;

View File

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iomanip>
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// The builder "build" method creates the final object.
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
return Status::OK();
}
// Constructor
EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) : RepeatOp(num_epoch) { MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; }
// Destructor
EpochCtrlOp::~EpochCtrlOp() {}
// A print method typically used for debugging
void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <EpochCtrlOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << " [epochs: " << max_repeats_ << "]\n";
} else {
// Call the super class for displaying any common detailed info
PipelineOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
<< "\nLeaf Nodes in execution path:";
if (!eoe_ops_.empty()) {
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << "\n Operator: " << eoe_ops_[i]->id();
}
} else {
out << " None.";
}
out << "\n\n";
}
}
Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
// `retry_if_eoe` is false because EpochCtrlOp does not eat EOE.
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, false));
// Only intercept EOE for EoeReceived processing, after that the EOE is forwarded to next op.
// Other databuffers containing data or EOF will simply be forwarded.
// EOF can simply be forwarded because this op does not spawn any thread, thus does not require clean up.
if (buf->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
}
*p_buffer = std::move(buf);
return Status::OK();
}
Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
repeat_count_++;
MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
<< ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;
// If we've reached the requested epoch count, then flag the leaf nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
eoe_op->set_control_flag(kDeOpLastRepeat);
}
}
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
state_ = OpState::kDeOpIdle;
if (repeat_count_ != max_repeats_) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
}
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
// Visitor accept method for NodePass
Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
namespace mindspore {
namespace dataset {
class EpochCtrlOp : public RepeatOp {
public:
class Builder : public RepeatOp::Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @param count - The number of repeats to do
// @return This is a constructor.
explicit Builder(int32_t count) : RepeatOp::Builder(count) {}
// Default destructor
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new EpochCtrlOp object
Status Build(std::shared_ptr<EpochCtrlOp> *);
};
// Contructor
explicit EpochCtrlOp(int32_t num_epoch);
// Destructor
~EpochCtrlOp();
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us
// will simply bounce you to get a buffer from our child.
// Epoch Control Op does not eat the EOE, it will pass the EOE to the next op.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// Base-class override for handling cases when an eoe is received.
// @param worker_id - The worker id
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_

View File

@ -132,6 +132,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
// Invoke a reset against the eoe nodes only. // Invoke a reset against the eoe nodes only.
for (auto &eoe_op : eoe_ops_) { for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset()); RETURN_IF_NOT_OK(eoe_op->Reset());
} }
@ -167,8 +168,9 @@ int32_t RepeatOp::num_consumers() const {
Status RepeatOp::Reset() { Status RepeatOp::Reset() {
// If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
// In that case, we now have to bounce the reset down to our own eoe ops. // In that case, we now have to bounce the reset down to our own eoe ops.
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset."; MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
for (auto &eoe_op : eoe_ops_) { for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset()); RETURN_IF_NOT_OK(eoe_op->Reset());
} }
state_ = OpState::kDeOpRunning; state_ = OpState::kDeOpRunning;

View File

@ -46,7 +46,7 @@ class RepeatOp : public PipelineOp {
// @return shared_ptr to the new RepeatOp object // @return shared_ptr to the new RepeatOp object
Status Build(std::shared_ptr<RepeatOp> *); Status Build(std::shared_ptr<RepeatOp> *);
private: protected:
int32_t build_max_repeats_; int32_t build_max_repeats_;
Status SanityCheck() const; Status SanityCheck() const;
@ -131,11 +131,11 @@ class RepeatOp : public PipelineOp {
// @return Name of the current Op // @return Name of the current Op
std::string Name() const override { return "RepeatOp"; } std::string Name() const override { return "RepeatOp"; }
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
/// \param[in] eoe_op The input leaf/eoe operator to add to the list // \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
private: protected:
int32_t max_repeats_; // The number of repeats that the user requested int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats int32_t repeat_count_; // A counter for the current number of executed repeats
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.

View File

@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) {
if (eof_) { if (eof_) {
return Status::OK(); return Status::OK();
} }
// One of our child iterators encounter EOE. Returns and proceed with draining phase.
if (new_row.empty()) { if (new_row.empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!"); return Status::OK();
} }
// Pack this first row into our tensor table // Pack this first row into our tensor table

View File

@ -23,6 +23,7 @@
#include "minddata/dataset/engine/opt/pre/removal_pass.h" #include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h" #include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h" #include "minddata/dataset/engine/perf/monitor.h"
@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
if (op->tree_ == this) { if (op->tree_ == this) {
return Status::OK(); return Status::OK();
} }
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) {
std::string err_msg = std::string err_msg =
"Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) + "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
" Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " + " Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " +
std::to_string(static_cast<int>(kDeTStateBuilding)); std::to_string(static_cast<int>(kDeTStateBuilding)) + " or " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// For example, repeatOp inlining // For example, repeatOp inlining
// //
// @return Status - The error code return // @return Status - The error code return
Status ExecutionTree::Prepare() { Status ExecutionTree::Prepare(int32_t num_epochs) {
num_epochs_ = num_epochs;
// Pre optimization compulsory transformation // Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction()); RETURN_IF_NOT_OK(this->PrepareTreePreAction());
@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std::vector<std::unique_ptr<Pass>> pre_actions; std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions // Construct pre actions
MS_LOG(INFO) << "Running pre pass loops."; MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<InjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>()); pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>()); pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes // Apply pre action passes
@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() {
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare)); " Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
if (root_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");
}
// Start the recursive prepare // Start the recursive prepare
RETURN_IF_NOT_OK(this->PrepareNode(root_)); RETURN_IF_NOT_OK(this->PrepareNode(root_));
tree_state_ = kDeTStateReady; tree_state_ = kDeTStateReady;

View File

@ -176,7 +176,7 @@ class ExecutionTree {
// For example, repeatOp inlining // For example, repeatOp inlining
// //
// @return Status - The error code return // @return Status - The error code return
Status Prepare(); Status Prepare(int num_epochs = -1);
// Compulsory transformation/action pre optimization. // Compulsory transformation/action pre optimization.
// @return Status - The error code return // @return Status - The error code return
@ -193,6 +193,7 @@ class ExecutionTree {
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get // walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution. // it ready for execution.
// @param Total number of epochs that will be run on this tree
// @return Status - The error code return // @return Status - The error code return
Status PrepareDeprecated(); Status PrepareDeprecated();
@ -231,6 +232,10 @@ class ExecutionTree {
// Optional optimizations status // Optional optimizations status
bool OptimizationEnabled() const { return optimize_; } bool OptimizationEnabled() const { return optimize_; }
// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
private: private:
// A helper functions for doing the recursive printing // A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print // @param dataset_op - The dataset op to print
@ -245,6 +250,7 @@ class ExecutionTree {
int32_t id_count_; // Counter for generating operator id's int32_t id_count_; // Counter for generating operator id's
uint32_t prepare_flags_; // Flags used during tree prepare uint32_t prepare_flags_; // Flags used during tree prepare
TreeState tree_state_; // Tracking the current tree state TreeState tree_state_; // Tracking the current tree state
int32_t num_epochs_; // Total number of epochs to run for this tree
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations bool optimize_; // Flag to enable optional optimizations

View File

@ -5,6 +5,7 @@ add_library(engine-opt OBJECT
post/repeat_pass.cc post/repeat_pass.cc
pre/cache_pass.cc pre/cache_pass.cc
pre/cache_transform_pass.cc pre/cache_transform_pass.cc
pre/injection_pass.cc
pre/removal_nodes.cc pre/removal_nodes.cc
pre/removal_pass.cc pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc optional/tensor_op_fusion_pass.cc

View File

@ -16,11 +16,13 @@
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h" #include "minddata/dataset/engine/datasetops/rename_op.h"
@ -230,6 +232,11 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
} }
Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default // Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -244,5 +251,15 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified
// Fallback to base class visitor by default // Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
} }
Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -77,6 +77,10 @@ class CacheMergeOp;
class CacheLookupOp; class CacheLookupOp;
class EpochCtrlOp;
class BuildVocabOp;
// The base class Pass is the basic unit of tree transformation. // The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here. // The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> { class Pass : public std::enable_shared_from_this<Pass> {
@ -190,12 +194,18 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
private: private:
// Helper function to perform DFS visit // Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);

View File

@ -20,6 +20,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -28,6 +29,9 @@ RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(fa
// Identifies the subtree below this node as being in a repeated path of the tree. // Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
// If we are already repeated, then this is a nested repeat. // If we are already repeated, then this is a nested repeat.
if (is_repeated_) { if (is_repeated_) {
nested_repeats_++; nested_repeats_++;
@ -36,6 +40,18 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified)
return Status::OK(); return Status::OK();
} }
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// EpochCtrl is derived from RepeatOp. Generally it should do the identical setup
// that RepeatOp does. However, epoch control is actually simpler because it can
// only exist as the root node so it doesn't need all the nested code.
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
is_repeated_ = true;
return Status::OK();
}
// Identifies the subtree below this node as being in a cache merge path // Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Turn on the flag that we're under a merge op // Turn on the flag that we're under a merge op
@ -47,13 +63,24 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) { while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op); node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack(); leaf_op = PopFromEOEOpStack();
} }
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
// at this time, so we can pop it to get rid of it.
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
if (!current_stack->empty()) {
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
}
eoe_op_stacks_.pop();
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
// and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed
// from the save area, because the merge op above us may also take action on it later for a different
// case when there is no repeat in the merge leg.
if (is_merge_ && cache_lookup_) { if (is_merge_ && cache_lookup_) {
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
node->AddToEoeList(std::move(cache_lookup_)); node->AddToEoeList(std::move(cache_lookup_));
@ -65,16 +92,29 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
node->set_control_flag(DatasetOp::kDeOpRepeated); node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node); AddToEOEOpStack(node);
nested_repeats_--; nested_repeats_--;
} } else {
// If we are not nested, or we were the top-most repeat, now we clear the flag
// If we are not nested, or we were the top-most repeat, now we clear the flag if (nested_repeats_ != 0) {
if (nested_repeats_ == 0) { RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
}
is_repeated_ = false; is_repeated_ = false;
} }
return Status::OK(); return Status::OK();
} }
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
is_repeated_ = false;
return Status::OK();
}
// CacheOp removes previous leaf ops and replaces them with itself // CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (is_repeated_) { if (is_repeated_) {
@ -118,9 +158,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// Turns off the tracking for operations under merge op // Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Setting the flag is needed since we didn't call the base class DatasetOp version // Setting the flag is needed since we didn't call the base class DatasetOp version
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
if (cache_lookup_) {
AddToEOEOpStack(std::move(cache_lookup_));
}
}
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
is_merge_ = false; is_merge_ = false;
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
return Status::OK(); return Status::OK();
} }
@ -135,25 +182,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
if (is_repeated_) { if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated); node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node); // Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that.
} else {
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
} }
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
// Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
// add the lookup to the eoe stack
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
return Status::OK(); return Status::OK();
} }
// Adds an operator to the eoe operator stack save area // Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); } void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
current_stack->push(dataset_op);
}
// Pops an operator from the eoe operator stack save area // Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr; std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) { eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
top_op = eoe_stack_.top(); if (current_stack != nullptr && !current_stack->empty()) {
eoe_stack_.pop(); top_op = current_stack->top();
current_stack->pop();
} }
return top_op; return top_op;
} }

View File

@ -30,6 +30,8 @@ namespace dataset {
/// to the eoe-producing (typically leaf) nodes underneath it. /// to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass { class RepeatPass : public NodePass {
public: public:
using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>;
/// \brief Constructor /// \brief Constructor
RepeatPass(); RepeatPass();
@ -39,6 +41,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return /// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a cache merge path /// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
@ -51,6 +59,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
/// \brief CacheOp removes previous leaf ops and replaces them with itself /// \brief CacheOp removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
@ -86,11 +100,11 @@ class RepeatPass : public NodePass {
/// \return shared_ptr to the popped operator /// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack(); std::shared_ptr<DatasetOp> PopFromEOEOpStack();
bool is_repeated_; // T/F if we are processing under a repeat bool is_repeated_; // T/F if we are processing under a repeat
bool is_merge_; // T/F if we are processing under a cache merge op bool is_merge_; // T/F if we are processing under a cache merge op
int32_t nested_repeats_; // A counter for nested repeats int32_t nested_repeats_; // A counter for nested repeats
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
namespace mindspore {
namespace dataset {
// constructor
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
}
// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
}
// constructor
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}
// Runs an injection pass to inject in operators needed at the pre pass stage
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Injection pass started.";
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
// The finder can make updates to the InjectionPass object.
InjectionPass::InjectionFinder finder(this);
finder.Run(tree, modified);
// The first injection logic is to check if we should inject the epoch control op as the root node.
// Do not inject the op if the number of epochs is 1.
int32_t num_epochs = tree->num_epochs();
if (num_epochs != 1 && !epoch_ctrl_bypass_) {
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
std::shared_ptr<DatasetOp> node = tree->root();
if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
tree->root()->InsertAsParent(epoch_ctrl_op);
} else {
tree->root()->child(0)->InsertAsParent(epoch_ctrl_op);
}
}
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class InjectionPass injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
/// parsing.
class InjectionPass : public TreePass {
/// \class InjectionFinder
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
/// it may need to inject.
class InjectionFinder : public NodePass {
public:
/// \brief Constructor
explicit InjectionFinder(InjectionPass *injection_pass);
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;
/// \brief Temporary code to prevent the injection of epoch control when cache op is present.
/// Remove this code in cache op phase 2
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
private:
InjectionPass *injection_pass_;
};
public:
/// \brief Constructor
InjectionPass();
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
private:
bool epoch_ctrl_bypass_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_

View File

@ -29,20 +29,27 @@ std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
return instance_ptr_; return instance_ptr_;
} }
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) { TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
tdt::TdtDataType tdt_type) {
MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
std::vector<DataItem> items; std::vector<DataItem> items;
double start_time; double start_time;
auto ret = translate(ts_row, items); if (tdt_type == tdt::TDT_TENSOR) {
if (ret != SUCCESS) { auto ret = translate(ts_row, items);
MS_LOG(ERROR) << "TDT converting tensor failed!"; if (ret != SUCCESS) {
return FAILED; MS_LOG(ERROR) << "TDT converting tensor failed!";
return FAILED;
}
} else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) {
DataItem data_item;
data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE;
items.emplace_back(data_item);
MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE";
} }
if (profiling) { if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond(); start_time = ProfilingTime::GetCurMilliSecond();
} }
if (tdt::TdtHostPushData(channel_name, items) != 0) { if (tdt::TdtHostPushData(channel_name, items) != 0) {
MS_LOG(ERROR) << "TDT pushing data failed!";
return FAILED; return FAILED;
} }
if (profiling) { if (profiling) {
@ -122,8 +129,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i
data_item.dataPtr_ = data_item.dataPtr_ =
std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {}); std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
items.emplace_back(data_item); items.emplace_back(data_item);
MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is " MS_LOG(INFO) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes
<< ts->Size() << "."; << ", data length is " << ts->Size() << ".";
} }
return SUCCESS; return SUCCESS;
} }

View File

@ -38,7 +38,8 @@ class TdtPlugin {
public: public:
static std::shared_ptr<TdtPlugin> GetInstance(); static std::shared_ptr<TdtPlugin> GetInstance();
TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time); TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
tdt::TdtDataType tdt_type = tdt::TDT_TENSOR);
private: private:
TdtPlugin() {} TdtPlugin() {}

View File

@ -797,6 +797,9 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
(void)InitBackend(); (void)InitBackend();
} }
#endif #endif
if (iter_num == -1) {
iter_num = INT32_MAX;
}
if (name == kMsConvert || name == kMsVm) { if (name == kMsConvert || name == kMsVm) {
return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
} }

View File

@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try: try:
@ -946,14 +946,14 @@ class Dataset:
raise TypeError("apply_func must return a dataset.") raise TypeError("apply_func must return a dataset.")
return dataset return dataset
@check_positive_int32 def device_que(self, prefetch_size=None, send_epoch_end=True):
def device_que(self, prefetch_size=None):
""" """
Return a transferredDataset that transfer data through device. Return a transferredDataset that transfer data through device.
Args: Args:
prefetch_size (int, optional): prefetch number of records ahead of the prefetch_size (int, optional): prefetch number of records ahead of the
user's request (default=None). user's request (default=None).
send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)
Note: Note:
If device is Ascend, features of data will be transferred one by one. The limitation If device is Ascend, features of data will be transferred one by one. The limitation
@ -962,15 +962,14 @@ class Dataset:
Return: Return:
TransferDataset, dataset for transferring. TransferDataset, dataset for transferring.
""" """
return self.to_device() return self.to_device(send_epoch_end=send_epoch_end)
@check_positive_int32 def to_device(self, send_epoch_end=True):
def to_device(self, num_batch=None):
""" """
Transfer data through CPU, GPU or Ascend devices. Transfer data through CPU, GPU or Ascend devices.
Args: Args:
num_batch (int, optional): limit the number of batch to be sent to device (default=None). send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)
Note: Note:
If device is Ascend, features of data will be transferred one by one. The limitation If device is Ascend, features of data will be transferred one by one. The limitation
@ -982,19 +981,9 @@ class Dataset:
Raises: Raises:
TypeError: If device_type is empty. TypeError: If device_type is empty.
ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'. ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
ValueError: If num_batch is not positive or larger than int_max.
ValueError: If dataset size is None or 0.
RuntimeError: If dataset is unknown. RuntimeError: If dataset is unknown.
RuntimeError: If distribution file path is given but failed to read. RuntimeError: If distribution file path is given but failed to read.
""" """
if self.get_dataset_size() is None or 0:
raise ValueError("dataset size is None or 0.")
if num_batch is None:
num_batch = self.get_dataset_size()
repeat_count = self.get_repeat_count()
num_batch = num_batch * repeat_count
queue_name = str(uuid.uuid1()) queue_name = str(uuid.uuid1())
if context: if context:
@ -1008,9 +997,6 @@ class Dataset:
if device_type not in ('Ascend', 'GPU', 'CPU'): if device_type not in ('Ascend', 'GPU', 'CPU'):
raise ValueError("Only support CPU, Ascend, GPU") raise ValueError("Only support CPU, Ascend, GPU")
if num_batch == 0:
raise ValueError("num_batch is 0.")
def get_distribution(output_dataset): def get_distribution(output_dataset):
dev_id = 0 dev_id = 0
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2, if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
@ -1032,7 +1018,7 @@ class Dataset:
distribution_path, device_id = get_distribution(self) distribution_path, device_id = get_distribution(self)
if distribution_path == "": if distribution_path == "":
return TransferDataset(self, queue_name, device_id, device_type, num_batch) return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
try: try:
with open(distribution_path, 'r') as distribution_f: with open(distribution_path, 'r') as distribution_f:
dist = json.load(distribution_f) dist = json.load(distribution_f)
@ -1042,7 +1028,7 @@ class Dataset:
except Exception: except Exception:
raise RuntimeError("Distribution file failed to read") raise RuntimeError("Distribution file failed to read")
return TransferDataset(self, queue_name, device_id, device_type, num_batch) return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
@check_save @check_save
def save(self, file_name, num_files=1, file_type='mindrecord'): def save(self, file_name, num_files=1, file_type='mindrecord'):
@ -1072,7 +1058,7 @@ class Dataset:
return SaveOp(self).save(file_names, file_type) return SaveOp(self).save(file_names, file_type)
def create_tuple_iterator(self, columns=None): def create_tuple_iterator(self, columns=None, num_epochs=-1):
""" """
Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data. Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
@ -1098,9 +1084,9 @@ class Dataset:
""" """
if self._noop_mode(): if self._noop_mode():
return DummyIterator(self, 'tuple') return DummyIterator(self, 'tuple')
return TupleIterator(self, columns) return TupleIterator(self, columns, num_epochs)
def create_dict_iterator(self): def create_dict_iterator(self, num_epochs=-1):
""" """
Create an Iterator over the dataset. Create an Iterator over the dataset.
@ -1123,7 +1109,7 @@ class Dataset:
""" """
if self._noop_mode(): if self._noop_mode():
return DummyIterator(self, 'dict') return DummyIterator(self, 'dict')
return DictIterator(self) return DictIterator(self, num_epochs)
def __iter__(self): def __iter__(self):
"""Create an Iterator over the dataset.""" """Create an Iterator over the dataset."""
@ -1149,7 +1135,7 @@ class Dataset:
self._batch_size = device_iter.get_batch_size() self._batch_size = device_iter.get_batch_size()
self._num_classes = device_iter.num_classes() self._num_classes = device_iter.num_classes()
self._repeat_count = device_iter.get_repeat_count() self._repeat_count = device_iter.get_repeat_count()
device_iter.release() device_iter.stop()
def output_shapes(self): def output_shapes(self):
""" """
@ -2085,7 +2071,7 @@ class RepeatDataset(DatasetOp):
""" """
child_size = self.children[0].get_dataset_size() child_size = self.children[0].get_dataset_size()
if child_size is not None: if child_size is not None:
return child_size return child_size * self.count
return None return None
def get_repeat_count(self): def get_repeat_count(self):
@ -2097,7 +2083,6 @@ class RepeatDataset(DatasetOp):
""" """
return self.count return self.count
class SkipDataset(DatasetOp): class SkipDataset(DatasetOp):
""" """
The result of applying Skip operator to the input Dataset. The result of applying Skip operator to the input Dataset.
@ -2317,10 +2302,10 @@ class TransferDataset(DatasetOp):
queue_name (str): Name of device queue. queue_name (str): Name of device queue.
device_id (int): Id of device. device_id (int): Id of device.
device_type (str): Type of device, including "CPU", "GPU", and "Ascend". device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
num_batch (int): limit the number of batch to be sent to device (default=None). send_epoch_end (bool, optional): Whether send end of sequence to device or not.(default=True)
""" """
def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None): def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
super().__init__() super().__init__()
self.children.append(input_dataset) self.children.append(input_dataset)
input_dataset.parent.append(self) input_dataset.parent.append(self)
@ -2328,7 +2313,7 @@ class TransferDataset(DatasetOp):
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
self._device_type = device_type self._device_type = device_type
self._device_id = device_id self._device_id = device_id
self.__num_batch = num_batch self._send_epoch_end = send_epoch_end
self.iterator = None self.iterator = None
def get_args(self): def get_args(self):
@ -2336,13 +2321,13 @@ class TransferDataset(DatasetOp):
args["queue_name"] = self.queue_name args["queue_name"] = self.queue_name
args["device_type"] = self._device_type args["device_type"] = self._device_type
args["device_id"] = self._device_id args["device_id"] = self._device_id
args["num_batch"] = self.__num_batch args["send_epoch_end"] = self._send_epoch_end
return args return args
def create_dict_iterator(self): def create_dict_iterator(self, num_epochs=-1):
raise RuntimeError("TransferDataset is not iterable") raise RuntimeError("TransferDataset is not iterable")
def create_tuple_iterator(self, columns=None): def create_tuple_iterator(self, columns=None, num_epochs=-1):
raise RuntimeError("TransferDataset is not iterable") raise RuntimeError("TransferDataset is not iterable")
def __iter__(self): def __iter__(self):
@ -2354,12 +2339,14 @@ class TransferDataset(DatasetOp):
def output_types(self): def output_types(self):
raise RuntimeError("TransferDataset does not support output_types") raise RuntimeError("TransferDataset does not support output_types")
def send(self): def send(self, num_epochs=-1):
# need to keep iterator alive so the executionTree is not destroyed # need to keep iterator alive so the executionTree is not destroyed
if self._noop_mode(): if self._noop_mode():
return return
self.iterator = TupleIterator(self) self.iterator = TupleIterator(self, num_epochs=-1)
def stop_send(self):
self.iterator.depipeline.StopSend()
class RangeDataset(MappableDataset): class RangeDataset(MappableDataset):
""" """

View File

@ -29,7 +29,6 @@ from . import datasets as de
ITERATORS_LIST = list() ITERATORS_LIST = list()
def _cleanup(): def _cleanup():
"""Release all the Iterator.""" """Release all the Iterator."""
for itr_ref in ITERATORS_LIST: for itr_ref in ITERATORS_LIST:
@ -60,7 +59,6 @@ def _alter_node(node):
node.iterator_bootstrap() node.iterator_bootstrap()
return node return node
class Iterator: class Iterator:
""" """
General Iterator over a dataset. General Iterator over a dataset.
@ -69,10 +67,21 @@ class Iterator:
dataset: Dataset to be iterated over dataset: Dataset to be iterated over
""" """
def __init__(self, dataset): def __init__(self, dataset, num_epochs=-1):
self.num_epochs = num_epochs
ITERATORS_LIST.append(weakref.ref(self)) ITERATORS_LIST.append(weakref.ref(self))
# create a copy of tree and work on it. # create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset) self.dataset = copy.deepcopy(dataset)
self.parent_subtree = []
# The dataset passed into the iterator is not the root of the tree.
# Trim the tree by saving the parent subtree into self.parent_subtree and
# restore it after launching our c++ pipeline.
if self.dataset.parent:
logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
self.parent_subtree = self.dataset.parent
self.dataset.parent = []
self.dataset = alter_tree(self.dataset) self.dataset = alter_tree(self.dataset)
if not self.__is_tree(): if not self.__is_tree():
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
@ -83,9 +92,17 @@ class Iterator:
root = self.__convert_node_postorder(self.dataset) root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root) self.depipeline.AssignRootNode(root)
self.depipeline.LaunchTreeExec() self.depipeline.LaunchTreeExec(self.num_epochs)
self._index = 0 self._index = 0
def stop(self):
"""
Manually terminate python iterator instead of relying on out of scope destruction.
"""
logger.info("terminating python iterator. This will also terminate c++ pipeline.")
if hasattr(self, 'depipeline') and self.depipeline:
del self.depipeline
def __is_tree_node(self, node): def __is_tree_node(self, node):
"""Check if a node is tree node.""" """Check if a node is tree node."""
if not node.children: if not node.children:
@ -214,9 +231,14 @@ class Iterator:
@abstractmethod @abstractmethod
def get_next(self): def get_next(self):
pass raise RuntimeError("Calling base class Iterator's get_next is invalid.")
def __next__(self): def __next__(self):
if not self.depipeline:
logger.warning("Iterator does not have a running c++ pipeline." +
"It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
raise RuntimeError("Iterator does not have a running c++ pipeline.")
data = self.get_next() data = self.get_next()
if not data: if not data:
if self._index == 0: if self._index == 0:
@ -293,12 +315,12 @@ class TupleIterator(Iterator):
def check_node_type(self, node): def check_node_type(self, node):
pass pass
def __init__(self, dataset, columns=None): def __init__(self, dataset, columns=None, num_epochs=-1):
if columns is not None: if columns is not None:
if not isinstance(columns, list): if not isinstance(columns, list):
columns = [columns] columns = [columns]
dataset = dataset.project(columns) dataset = dataset.project(columns)
super().__init__(dataset) super().__init__(dataset, num_epochs)
def __iter__(self): def __iter__(self):
return self return self

View File

@ -57,7 +57,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
# transform data format # transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
exec_dataset = exec_dataset.device_que() send_epoch_end = bool(dataset_size == -1)
exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end)
_executor.init_dataset(exec_dataset.queue_name, _executor.init_dataset(exec_dataset.queue_name,
dataset_size, dataset_size,
@ -126,7 +127,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
def _to_tensor(elem, scaling_sens=None): def _to_tensor(elem, scaling_sens=None):
"""Conver numpy to tensor, adapt to minddata feed solution.""" """Convert numpy to tensor, adapt to feed the data from host solution."""
lst = [] lst = []
if not isinstance(elem, (tuple, list)): if not isinstance(elem, (tuple, list)):
elem = [elem] elem = [elem]
@ -145,7 +146,8 @@ def _to_tensor(elem, scaling_sens=None):
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None): def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
"""Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution.""" """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
from host solution."""
lst = [] lst = []
if not isinstance(elem, (tuple, list)): if not isinstance(elem, (tuple, list)):
elem = [elem] elem = [elem]

View File

@ -16,7 +16,7 @@
import math import math
import os import os
from mindspore._checkparam import check_bool from mindspore._checkparam import check_bool, check_int
from .. import context from .. import context
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list, _to_full_shapes, _to_full_tensor _construct_tensor_list, _to_full_shapes, _to_full_tensor
@ -42,17 +42,23 @@ class DatasetHelper:
The iter of DatasetHelper will give one epoch data. The iter of DatasetHelper will give one epoch data.
Args: Args:
dataset (DataSet): The dataset. dataset (DataSet): The training dataset iterator.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
Default: True. sink_size (int): Control the amount of data each sink.
If sink_size=-1, sink the complete dataset each epoch.
If sink_size>0, sink sink_size data each epoch. Default: -1.
Examples: Examples:
>>> dataset_helper = DatasetHelper(dataset) >>> dataset_helper = DatasetHelper(dataset)
>>> for inputs in dataset_helper: >>> for inputs in dataset_helper:
>>> outputs = network(*inputs) >>> outputs = network(*inputs)
""" """
def __init__(self, dataset, dataset_sink_mode=True):
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
check_bool(dataset_sink_mode) check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
if dataset_sink_mode: if dataset_sink_mode:
if context.get_context("enable_ge"): if context.get_context("enable_ge"):
@ -68,9 +74,10 @@ class DatasetHelper:
iterclass = _DatasetIterMS iterclass = _DatasetIterMS
elif context.get_context("device_target") == "CPU": elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
self.iter = iterclass(dataset, sink_size)
else: else:
iterclass = _DatasetIterFeed iterclass = _DatasetIterNormal
self.iter = iterclass(dataset) self.iter = iterclass(dataset)
def __iter__(self): def __iter__(self):
return self.iter.__iter__() return self.iter.__iter__()
@ -80,21 +87,26 @@ class DatasetHelper:
"""Get the types and shapes from dataset on current config.""" """Get the types and shapes from dataset on current config."""
return self.iter.types_shapes() return self.iter.types_shapes()
def loop_size(self): def sink_size(self):
"""Get loop_size for every iteration.""" """Get sink_size for every iteration."""
return self.iter.loop_size return self.iter.get_sink_size()
def stop_send(self):
"""Free up resources about data sink."""
self.iter.stop_send()
class _DatasetIter: class _DatasetIter:
"""Base iter for dataset help""" """Base iter for dataset helper"""
def __init__(self, dataset): def __init__(self, dataset, sink_size):
if not hasattr(dataset, '__loop_size__'): self.dataset = dataset
self.loop_size = dataset.get_dataset_size() self.sink_size = sink_size
else: self.sink_count = 1
self.loop_size = dataset.__loop_size__
if not hasattr(dataset, '__ME_INITED__'): if not hasattr(dataset, '__TRANSFER_DATASET__'):
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
@ -102,43 +114,70 @@ class _DatasetIter:
else: else:
_send_data(dataset) _send_data(dataset)
self.ind = 0 self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
self.dataset = dataset self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
def __iter__(self): def __iter__(self):
self.ind = 0 self.index = 0
return self return self
def __next__(self): def __next__(self):
if self.ind >= self.loop_count: if self.index >= self.sink_count:
raise StopIteration() raise StopIteration()
self.ind += 1 self.index += 1
return self.op() return self.op()
def types_shapes(self): def types_shapes(self):
return self.dataset_types, self.dataset_shapes return self.dataset_types, self.dataset_shapes
def get_loop_count(self, dataset): def get_sink_count(self, dataset):
loop_count = 1 sink_count = 1
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ loop_size = dataset.__loop_size__
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0: if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.') f'sink_size {loop_size} are not matched.')
loop_count = math.ceil(dataset.get_dataset_size() / loop_size) sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
return loop_count return sink_count
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
if self.sink_size > 0:
sink_size = self.sink_size
else:
sink_size = self.dataset.get_dataset_size()
return sink_size
class _DatasetIterGE(_DatasetIter):
"""Iter for GE."""
def __init__(self, dataset, sink_size):
super().__init__(dataset, sink_size)
self.sink_count = self.get_sink_count(dataset)
batch_expand_num = 1
if _need_to_full():
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
return tensor_list_run
self.op = op
class _DatasetIterMSLoopSink(_DatasetIter): class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (device_target=Ascend)""" """Iter for context (device_target=Ascend)"""
def __init__(self, dataset): def __init__(self, dataset, sink_size):
super(_DatasetIterMSLoopSink, self).__init__(dataset) super().__init__(dataset, sink_size)
self.loop_count = self.get_loop_count(dataset) self.sink_count = self.get_sink_count(dataset)
ms_role = os.getenv("MS_ROLE") ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"): if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.loop_count = 1 self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
@ -153,66 +192,42 @@ class _DatasetIterMSLoopSink(_DatasetIter):
class _DatasetIterMS(_DatasetIter): class _DatasetIterMS(_DatasetIter):
"""Iter for context (device_target=GPU)""" """Iter for MS(enable_loop_sink=False)."""
def __init__(self, dataset): def __init__(self, dataset, sink_size):
super(_DatasetIterMS, self).__init__(dataset) super().__init__(dataset, sink_size)
self.loop_count = dataset.get_dataset_size() if sink_size > 0:
self.loop_size = 1 self.sink_count = sink_size
else:
self.sink_count = dataset.get_dataset_size()
queue_name = dataset.__ME_INITED__ queue_name = dataset.__ME_INITED__
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
class _DatasetIterPSLite(_DatasetIter): class _DatasetIterPSLite(_DatasetIter):
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
def __init__(self, dataset): def __init__(self, dataset, sink_size):
super(_DatasetIterPSLite, self).__init__(dataset) super().__init__(dataset, sink_size)
self.loop_count = 1 self.sink_count = 1
self.loop_size = 1 self.sink_size = 1
self.op = None self.op = None
def op(): def op():
return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1) return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
self.op = op self.op = op
class _DatasetIterGE(_DatasetIter): class _DatasetIterNormal:
"""Iter for ge"""
def __init__(self, dataset):
super(_DatasetIterGE, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
batch_expand_num = 1
if _need_to_full():
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
return tensor_list_run
self.op = op
class _DatasetIterFeed:
"""Iter for normal(non sink) mode, feed the data from host.""" """Iter for normal(non sink) mode, feed the data from host."""
def __init__(self, dataset): def __init__(self, dataset):
self.dataset = dataset self.dataset = dataset
self.device_num = _get_device_num() self.device_num = _get_device_num()
self.global_rank = _get_global_rank() self.global_rank = _get_global_rank()
self.repeat_count = dataset.get_repeat_count()
self.repeat_ind = 0
self.loop_count = dataset.get_dataset_size()
self.ind = 0
def __iter__(self): def __iter__(self):
if self.repeat_ind % self.repeat_count == 0: self.iter = self.dataset.create_tuple_iterator()
self.iter = self.dataset.__iter__()
self.repeat_ind += 1
self.ind = 0
return self return self
def __next__(self): def __next__(self):
if self.ind >= self.loop_count:
raise StopIteration()
self.ind += 1
data = self.iter.__next__() data = self.iter.__next__()
if _need_to_full(): if _need_to_full():
return _to_full_tensor(data, self.device_num, self.global_rank) return _to_full_tensor(data, self.device_num, self.global_rank)

View File

@ -21,7 +21,7 @@ import numpy as np
from mindspore import log as logger from mindspore import log as logger
from ..common.tensor import Tensor from ..common.tensor import Tensor
from ..nn.metrics import get_metrics from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
@ -225,7 +225,7 @@ class Model:
scaling_sens /= self._device_number scaling_sens /= self._device_number
return scaling_sens return scaling_sens
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode): def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
"""Initializes dataset.""" """Initializes dataset."""
need_wrap = False need_wrap = False
if dataset_sink_mode: if dataset_sink_mode:
@ -237,7 +237,7 @@ class Model:
if not is_train: if not is_train:
dataset.__loop_size__ = 1 dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode) dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)
# remove later to deal with loop sink # remove later to deal with loop sink
if need_wrap: if need_wrap:
@ -317,7 +317,7 @@ class Model:
self._eval_network.compile(*inputs) self._eval_network.compile(*inputs)
break break
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
""" """
Training. Training.
@ -332,6 +332,7 @@ class Model:
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with Configure pynative mode, the training process will be performed with
dataset not sink. dataset not sink.
sink_size (int): Control the amount of data each sink. Default: -1.
""" """
epoch = check_int_positive(epoch) epoch = check_int_positive(epoch)
self._train_network.set_train() self._train_network.set_train()
@ -342,7 +343,10 @@ class Model:
cb_params = _InternalCallbackParam() cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network cb_params.train_network = self._train_network
cb_params.epoch_num = epoch cb_params.epoch_num = epoch
cb_params.batch_num = train_dataset.get_dataset_size() if dataset_sink_mode and sink_size > 0:
cb_params.batch_num = sink_size
else:
cb_params.batch_num = train_dataset.get_dataset_size()
cb_params.mode = "train" cb_params.mode = "train"
cb_params.loss_fn = self._loss_fn cb_params.loss_fn = self._loss_fn
cb_params.optimizer = self._optimizer cb_params.optimizer = self._optimizer
@ -364,7 +368,7 @@ class Model:
"So the training process will be performed with dataset not sink.") "So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
else: else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
@staticmethod @staticmethod
def _transform_callbacks(callbacks): def _transform_callbacks(callbacks):
@ -377,7 +381,7 @@ class Model:
return [callbacks] return [callbacks]
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
""" """
Training process. The data would be passed to network through dataset channel. Training process. The data would be passed to network through dataset channel.
@ -390,17 +394,18 @@ class Model:
function respectively. function respectively.
list_callback (Callback): Executor of callback list. Default: None. list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data each sink. Default: -1.
""" """
dataset_helper, train_network = self._exec_preprocess(self._train_network, dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True, is_train=True,
phase='train', phase='train',
dataset=train_dataset, dataset=train_dataset,
dataset_sink_mode=True) dataset_sink_mode=True,
sink_size=sink_size)
self._train_network = train_network self._train_network = train_network
cb_params.train_network = self._train_network cb_params.train_network = self._train_network
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
loop_size = dataset_helper.loop_size()
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
list_callback.begin(run_context) list_callback.begin(run_context)
@ -412,9 +417,9 @@ class Model:
# for data sink dataset_helper only iter once, other wise iter epoch_size times. # for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper: for inputs in dataset_helper:
cb_params.cur_step_num += loop_size
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
outputs = self._train_network(*inputs) outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size()
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
list_callback.step_end(run_context) list_callback.step_end(run_context)
@ -422,6 +427,7 @@ class Model:
should_stop = should_stop or run_context.get_stop_requested() should_stop = should_stop or run_context.get_stop_requested()
if should_stop: if should_stop:
break break
dataset_helper.stop_send()
list_callback.end(run_context) list_callback.end(run_context)
@ -490,7 +496,7 @@ class Model:
list_callback.end(run_context) list_callback.end(run_context)
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
""" """
Training API where the iteration is controlled by python front-end. Training API where the iteration is controlled by python front-end.
@ -515,7 +521,10 @@ class Model:
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with Configure pynative mode, the training process will be performed with
dataset not sink. dataset not sink.
sink_size (int): Control the amount of data each sink.
If sink_size=-1, sink the complete dataset each epoch.
If sink_size>0, sink sink_size data each epoch.
If dataset_sink_mode is False, set sink_size invalid. Default: -1.
Examples: Examples:
>>> dataset = get_dataset() >>> dataset = get_dataset()
@ -526,17 +535,19 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset) >>> model.train(2, dataset)
""" """
repeat_count = train_dataset.get_repeat_count()
if epoch != repeat_count and dataset_sink_mode is True:
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
check_bool(dataset_sink_mode) check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
_device_number_check(self._parallel_mode, self._device_number) _device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train(epoch, self._train(epoch,
train_dataset, train_dataset,
callbacks=callbacks, callbacks=callbacks,
dataset_sink_mode=dataset_sink_mode) dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None): def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
""" """

View File

@ -43,7 +43,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size) ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
network = AlexNet(cfg.num_classes) network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size())) lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))

View File

@ -57,7 +57,7 @@ if __name__ == '__main__':
ds_train = create_dataset(args_opt.dataset_path, ds_train = create_dataset(args_opt.dataset_path,
train_mode=True, train_mode=True,
epochs=train_config.train_epochs, epochs=1,
batch_size=train_config.batch_size, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format), data_type=DataType(data_config.data_format),
rank_size=rank_size, rank_size=rank_size,
@ -82,7 +82,7 @@ if __name__ == '__main__':
if args_opt.do_eval: if args_opt.do_eval:
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=train_config.train_epochs, epochs=1,
batch_size=train_config.batch_size, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format)) data_type=DataType(data_config.data_format))
eval_callback = EvalCallBack(model, ds_eval, auc_metric, eval_callback = EvalCallBack(model, ds_eval, auc_metric,

View File

@ -66,7 +66,7 @@ if __name__ == "__main__":
init() init()
args_opt.base_size = config.crop_size args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size args_opt.crop_size = config.crop_size
train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train") train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train")
dataset_size = train_dataset.get_dataset_size() dataset_size = train_dataset.get_dataset_size()
time_cb = TimeMonitor(data_size=dataset_size) time_cb = TimeMonitor(data_size=dataset_size)
callback = [time_cb, LossCallBack()] callback = [time_cb, LossCallBack()]

View File

@ -94,7 +94,7 @@ if __name__ == '__main__':
loss_scale = float(config.loss_scale) loss_scale = float(config.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0. # When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=config.epoch_size, dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1,
batch_size=config.batch_size, device_num=device_num, rank_id=rank) batch_size=config.batch_size, device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()

View File

@ -78,7 +78,7 @@ if __name__ == '__main__':
mirror_mean=True) mirror_mean=True)
init() init()
dataset = create_dataset(cfg.data_path, cfg.epoch_size) dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()
net = GoogleNet(num_classes=cfg.num_classes) net = GoogleNet(num_classes=cfg.num_classes)

View File

@ -45,8 +45,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size, cfg.batch_size)
cfg.epoch_size)
network = LeNet5(cfg.num_classes) network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")

View File

@ -44,7 +44,7 @@ args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
step_size = ds_train.get_dataset_size() step_size = ds_train.get_dataset_size()
# define fusion network # define fusion network

View File

@ -77,7 +77,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()}) model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============") print("============== Starting Training ==============")
ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)

View File

@ -249,7 +249,7 @@ def train_parallel(config: TransformerConfig):
pre_train_dataset = load_dataset( pre_train_dataset = load_dataset(
data_files=config.pre_train_dataset, data_files=config.pre_train_dataset,
batch_size=config.batch_size, epoch_count=config.epochs, batch_size=config.batch_size, epoch_count=1,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step, sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(), rank_size=MultiAscend.get_group_size(),
@ -257,7 +257,7 @@ def train_parallel(config: TransformerConfig):
) if config.pre_train_dataset else None ) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset( fine_tune_dataset = load_dataset(
data_files=config.fine_tune_dataset, data_files=config.fine_tune_dataset,
batch_size=config.batch_size, epoch_count=config.epochs, batch_size=config.batch_size, epoch_count=1,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step, sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(), rank_size=MultiAscend.get_group_size(),
@ -265,7 +265,7 @@ def train_parallel(config: TransformerConfig):
) if config.fine_tune_dataset else None ) if config.fine_tune_dataset else None
test_dataset = load_dataset( test_dataset = load_dataset(
data_files=config.test_dataset, data_files=config.test_dataset,
batch_size=config.batch_size, epoch_count=config.epochs, batch_size=config.batch_size, epoch_count=1,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step, sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(), rank_size=MultiAscend.get_group_size(),
@ -288,17 +288,17 @@ def train_single(config: TransformerConfig):
print(" | Starting training on single device.") print(" | Starting training on single device.")
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset, pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
epoch_count=config.epochs, epoch_count=1,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.pre_train_dataset else None sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset, fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
epoch_count=config.epochs, epoch_count=1,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
test_dataset = load_dataset(data_files=config.test_dataset, test_dataset = load_dataset(data_files=config.test_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
epoch_count=config.epochs, epoch_count=1,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.test_dataset else None sink_step=config.dataset_sink_step) if config.test_dataset else None

View File

@ -180,7 +180,7 @@ if __name__ == '__main__':
do_train=True, do_train=True,
config=config_gpu, config=config_gpu,
platform=args_opt.platform, platform=args_opt.platform,
repeat_num=epoch_size, repeat_num=1,
batch_size=config_gpu.batch_size) batch_size=config_gpu.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# resume # resume
@ -239,7 +239,7 @@ if __name__ == '__main__':
do_train=True, do_train=True,
config=config_ascend, config=config_ascend,
platform=args_opt.platform, platform=args_opt.platform,
repeat_num=epoch_size, repeat_num=1,
batch_size=config_ascend.batch_size) batch_size=config_ascend.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
if args_opt.pre_trained: if args_opt.pre_trained:

View File

@ -86,7 +86,7 @@ if __name__ == '__main__':
do_train=True, do_train=True,
config=config, config=config,
device_target=args_opt.device_target, device_target=args_opt.device_target,
repeat_num=epoch_size, repeat_num=1,
batch_size=config.batch_size) batch_size=config.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# load pre trained ckpt # load pre trained ckpt

View File

@ -181,7 +181,7 @@ if __name__ == '__main__':
do_train=True, do_train=True,
config=config_gpu, config=config_gpu,
platform=args_opt.platform, platform=args_opt.platform,
repeat_num=epoch_size, repeat_num=1,
batch_size=config_gpu.batch_size) batch_size=config_gpu.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# resume # resume
@ -240,7 +240,7 @@ if __name__ == '__main__':
do_train=True, do_train=True,
config=config_ascend, config=config_ascend,
platform=args_opt.platform, platform=args_opt.platform,
repeat_num=epoch_size, repeat_num=1,
batch_size=config_ascend.batch_size) batch_size=config_ascend.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
if args_opt.pre_trained: if args_opt.pre_trained:

View File

@ -36,12 +36,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd() _cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
""" do train """ """ do train """
if load_checkpoint_path == "": if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!") raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size() steps_per_epoch = dataset.get_dataset_size()
epoch_num = dataset.get_repeat_count()
# optimizer # optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
@ -176,11 +175,11 @@ def run_classifier():
assessment_method=assessment_method) assessment_method=assessment_method)
if args_opt.do_train.lower() == "true": if args_opt.do_train.lower() == "true":
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method, assessment_method=assessment_method,
data_file_path=args_opt.train_data_file_path, data_file_path=args_opt.train_data_file_path,
schema_file_path=args_opt.schema_file_path) schema_file_path=args_opt.schema_file_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
if save_finetune_checkpoint_path == "": if save_finetune_checkpoint_path == "":
@ -191,7 +190,7 @@ def run_classifier():
ds.get_dataset_size(), epoch_num, "classifier") ds.get_dataset_size(), epoch_num, "classifier")
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method, assessment_method=assessment_method,
data_file_path=args_opt.eval_data_file_path, data_file_path=args_opt.eval_data_file_path,
schema_file_path=args_opt.schema_file_path) schema_file_path=args_opt.schema_file_path)

View File

@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd() _cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
""" do train """ """ do train """
if load_checkpoint_path == "": if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!") raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size() steps_per_epoch = dataset.get_dataset_size()
epoch_num = dataset.get_repeat_count()
# optimizer # optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
@ -204,10 +203,10 @@ def run_ner():
use_crf=(args_opt.use_crf.lower() == "true"), use_crf=(args_opt.use_crf.lower() == "true"),
tag_to_index=tag_to_index, dropout_prob=0.1) tag_to_index=tag_to_index, dropout_prob=0.1)
if args_opt.do_train.lower() == "true": if args_opt.do_train.lower() == "true":
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
schema_file_path=args_opt.schema_file_path) schema_file_path=args_opt.schema_file_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
if save_finetune_checkpoint_path == "": if save_finetune_checkpoint_path == "":
@ -218,7 +217,7 @@ def run_ner():
ds.get_dataset_size(), epoch_num, "ner") ds.get_dataset_size(), epoch_num, "ner")
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
schema_file_path=args_opt.schema_file_path) schema_file_path=args_opt.schema_file_path)
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path, do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,

View File

@ -100,11 +100,12 @@ def run_pretrain():
bert_net_cfg.compute_type = mstype.float32 bert_net_cfg.compute_type = mstype.float32
ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
args_opt.enable_data_sink, args_opt.data_sink_steps, args_opt.enable_data_sink, args_opt.data_sink_steps,
args_opt.data_dir, args_opt.schema_dir) args_opt.data_dir, args_opt.schema_dir)
new_repeat_count = args_opt.epoch_size
if args_opt.train_steps > 0: if args_opt.train_steps > 0:
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps) new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True) netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
if cfg.optimizer == 'Lamb': if cfg.optimizer == 'Lamb':

View File

@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd() _cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
""" do train """ """ do train """
if load_checkpoint_path == "": if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!") raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size() steps_per_epoch = dataset.get_dataset_size()
epoch_num = dataset.get_repeat_count()
# optimizer # optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
@ -181,10 +180,10 @@ def run_squad():
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
if args_opt.do_train.lower() == "true": if args_opt.do_train.lower() == "true":
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
data_file_path=args_opt.train_data_file_path, data_file_path=args_opt.train_data_file_path,
schema_file_path=args_opt.schema_file_path) schema_file_path=args_opt.schema_file_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
if save_finetune_checkpoint_path == "": if save_finetune_checkpoint_path == "":
load_finetune_checkpoint_dir = _cur_dir load_finetune_checkpoint_dir = _cur_dir
@ -194,7 +193,7 @@ def run_squad():
ds.get_dataset_size(), epoch_num, "squad") ds.get_dataset_size(), epoch_num, "squad")
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
data_file_path=args_opt.eval_data_file_path, data_file_path=args_opt.eval_data_file_path,
schema_file_path=args_opt.schema_file_path, is_training=False) schema_file_path=args_opt.schema_file_path, is_training=False)
do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path, do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,

View File

@ -54,7 +54,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
ds = ds.map(input_columns="input_ids", operations=type_cast_op) ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations # apply batch operations
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(max(new_repeat_count, repeat_count))
logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count())) logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count return ds, new_repeat_count

View File

@ -17,7 +17,6 @@
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as deC import mindspore.dataset.transforms.c_transforms as deC
from mindspore import log as logger
from .config import transformer_net_cfg from .config import transformer_net_cfg
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true", def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
@ -42,7 +41,4 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count) ds = ds.repeat(repeat_count)
ds.channel_name = 'transformer' return ds
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, repeat_count

View File

@ -125,10 +125,10 @@ def run_transformer_train():
else: else:
device_num = 1 device_num = 1
rank_id = 0 rank_id = 0
dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num, dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle, rank_id=rank_id, do_shuffle=args.do_shuffle,
enable_data_sink=args.enable_data_sink, enable_data_sink=args.enable_data_sink,
dataset_path=args.data_path) dataset_path=args.data_path)
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True) netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
@ -165,7 +165,7 @@ def run_transformer_train():
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
if __name__ == '__main__': if __name__ == '__main__':
run_transformer_train() run_transformer_train()

View File

@ -88,10 +88,10 @@ if __name__ == '__main__':
# create dataset # create dataset
if args_opt.net == "resnet50": if args_opt.net == "resnet50":
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size, dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target) batch_size=config.batch_size, target=target)
else: else:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size, dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size) batch_size=config.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()

View File

@ -105,7 +105,7 @@ if __name__ == '__main__':
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_train: if args_opt.do_train:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size) batch_size=config.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

View File

@ -91,7 +91,7 @@ def main():
loss_scale = float(args_opt.loss_scale) loss_scale = float(args_opt.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0.
dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, dataset = create_ssd_dataset(mindrecord_file, repeat_num=1,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank) batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()

View File

@ -83,7 +83,7 @@ if __name__ == '__main__':
mirror_mean=True) mirror_mean=True)
init() init()
dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size) dataset = vgg_create_dataset(args_opt.data_path, 1)
batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()
net = vgg16(num_classes=cfg.num_classes) net = vgg16(num_classes=cfg.num_classes)

View File

@ -63,7 +63,7 @@ def test_train(configure):
data_path = configure.data_path data_path = configure.data_path
batch_size = configure.batch_size batch_size = configure.batch_size
epochs = configure.epochs epochs = configure.epochs
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_train.size: {}".format(ds_train.get_dataset_size()))
net_builder = ModelBuilder() net_builder = ModelBuilder()

View File

@ -67,8 +67,8 @@ def test_train_eval(config):
data_path = config.data_path data_path = config.data_path
batch_size = config.batch_size batch_size = config.batch_size
epochs = config.epochs epochs = config.epochs
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size) ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size) ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -85,14 +85,14 @@ def train_and_eval(config):
if config.full_batch: if config.full_batch:
context.set_auto_parallel_context(full_batch=True) context.set_auto_parallel_context(full_batch=True)
de.config.set_seed(1) de.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size*get_group_size()) batch_size=batch_size*get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size*get_group_size()) batch_size=batch_size*get_group_size())
else: else:
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -74,9 +74,9 @@ def train_and_eval(config):
batch_size = config.batch_size batch_size = config.batch_size
epochs = config.epochs epochs = config.epochs
print("epochs is {}".format(epochs)) print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -121,7 +121,7 @@ def main():
loss_scale = float(args_opt.loss_scale) loss_scale = float(args_opt.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, dataset = create_yolo_dataset(mindrecord_file,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank) batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()
print("Create dataset done!") print("Create dataset done!")

View File

@ -50,13 +50,20 @@ class MindData:
def input_indexs(self): def input_indexs(self):
return self._input_indexs return self._input_indexs
def device_que(self): def device_que(self, send_epoch_end=True):
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
self.send_epoch_end = send_epoch_end
return self return self
def create_tuple_iterator(self):
return self.__iter__()
def send(self): def send(self):
pass pass
def stop_send(self):
pass
def __len__(self): def __len__(self):
return self._size return self._size

View File

@ -73,7 +73,7 @@ if __name__ == "__main__":
epoch_size = 3 epoch_size = 3
args_opt.base_size = config.crop_size args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size args_opt.crop_size = config.crop_size
train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size, train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size,
usage="train", shuffle=False) usage="train", shuffle=False)
dataset_size = train_dataset.get_dataset_size() dataset_size = train_dataset.get_dataset_size()
callback = LossCallBack(dataset_size) callback = LossCallBack(dataset_size)

View File

@ -120,10 +120,10 @@ def test_transformer():
batch_size = 96 batch_size = 96
epoch_size = 3 epoch_size = 3
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
dataset, repeat_count = create_transformer_dataset(epoch_count=epoch_size, dataset = create_transformer_dataset(epoch_count=1,
do_shuffle="false", do_shuffle="false",
enable_data_sink="false", enable_data_sink="false",
dataset_path=DATA_DIR) dataset_path=DATA_DIR)
netwithloss = TransformerNetworkWithLoss(config, True) netwithloss = TransformerNetworkWithLoss(config, True)
@ -146,7 +146,7 @@ def test_transformer():
netwithgrads.set_train(True) netwithgrads.set_train(True)
time_monitor_callback = TimeMonitor(dataset.get_dataset_size()) time_monitor_callback = TimeMonitor(dataset.get_dataset_size())
model = Model(netwithgrads) model = Model(netwithgrads)
model.train(repeat_count, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False) model.train(epoch_size, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)
# assertion occurs while the loss value, overflow state or loss_scale value is wrong # assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list) loss_value = np.array(callback.loss_list)

View File

@ -79,9 +79,9 @@ def test_train_eval():
batch_size = config.batch_size batch_size = config.batch_size
epochs = config.epochs epochs = config.epochs
print("epochs is {}".format(epochs)) print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size, ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size()) data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size, ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size()) data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -76,9 +76,9 @@ def test_train_eval():
batch_size = config.batch_size batch_size = config.batch_size
epochs = config.epochs epochs = config.epochs
print("epochs is {}".format(epochs)) print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -113,7 +113,7 @@ def test_yolov3():
loss_scale = float(loss_scale) loss_scale = float(loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size, dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
batch_size=batch_size, device_num=device_num, rank=rank) batch_size=batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()
print("Create dataset done!") print("Create dataset done!")
@ -146,12 +146,12 @@ def test_yolov3():
assert loss_value[2] < expect_loss_value[2] assert loss_value[2] < expect_loss_value[2]
epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2] epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
expect_epoch_mseconds = 950 expect_epoch_mseconds = 2000
print("epoch mseconds: {}".format(epoch_mseconds)) print("epoch mseconds: {}".format(epoch_mseconds))
assert epoch_mseconds <= expect_epoch_mseconds assert epoch_mseconds <= expect_epoch_mseconds
per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2] per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
expect_per_step_mseconds = 110 expect_per_step_mseconds = 220
print("per step mseconds: {}".format(per_step_mseconds)) print("per step mseconds: {}".format(per_step_mseconds))
assert per_step_mseconds <= expect_per_step_mseconds assert per_step_mseconds <= expect_per_step_mseconds
print("yolov3 test case passed.") print("yolov3 test case passed.")

View File

@ -91,6 +91,7 @@ def me_de_train_dataset(sink_mode=False):
"""test me de train dataset""" """test me de train dataset"""
# apply repeat operations # apply repeat operations
repeat_count = 1 repeat_count = 1
sink_size = -1
batch_size = 16 batch_size = 16
ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
"next_sentence_labels", "masked_lm_positions", "next_sentence_labels", "masked_lm_positions",
@ -99,9 +100,9 @@ def me_de_train_dataset(sink_mode=False):
new_repeat_count = repeat_count new_repeat_count = repeat_count
if sink_mode: if sink_mode:
repeat_count = 30 repeat_count = 30
sink_steps = 100 sink_size = 100
ori_dataaet_size = ds.get_dataset_size() ori_dataaet_size = ds.get_dataset_size()
new_size = sink_steps * batch_size new_size = sink_size * batch_size
ds.set_dataset_size(new_size) ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size()) new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size())
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
@ -112,10 +113,9 @@ def me_de_train_dataset(sink_mode=False):
ds = ds.map(input_columns="input_ids", operations=type_cast_op) ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations # apply batch operations
ds = ds.batch(batch_size, drop_remainder=True) ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeat_count: {}".format(ds.get_repeat_count())) logger.info("repeat_count: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count return ds, new_repeat_count, sink_size
def weight_variable(shape): def weight_variable(shape):
@ -157,7 +157,7 @@ class TimeMonitor(Callback):
def test_bert_percision(): def test_bert_percision():
"""test bert percision""" """test bert percision"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
ds, new_repeat_count = me_de_train_dataset() ds, new_repeat_count, _ = me_de_train_dataset()
version = os.getenv('VERSION', 'large') version = os.getenv('VERSION', 'large')
batch_size = 16 batch_size = 16
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
@ -215,7 +215,7 @@ def test_bert_percision():
def test_bert_performance(): def test_bert_performance():
"""test bert performance""" """test bert performance"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
ds, new_repeat_count = me_de_train_dataset(sink_mode=True) ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True)
version = os.getenv('VERSION', 'large') version = os.getenv('VERSION', 'large')
batch_size = 16 batch_size = 16
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
@ -251,7 +251,7 @@ def test_bert_performance():
param.default_input = weight_variable(value.asnumpy().shape) param.default_input = weight_variable(value.asnumpy().shape)
time_monitor_callback = TimeMonitor(ds.get_dataset_size()) time_monitor_callback = TimeMonitor(ds.get_dataset_size())
model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback], model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
dataset_sink_mode=True) dataset_sink_mode=True, sink_size=sink_size)
# assertion occurs while the loss value, overflow state or loss_scale value is wrong # assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list) loss_value = np.array(callback.loss_list)

View File

@ -79,7 +79,7 @@ def test_deeplabv3_1p():
args_opt.base_size = config.crop_size args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size args_opt.crop_size = config.crop_size
args_opt.batch_size = config.batch_size args_opt.batch_size = config.batch_size
train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size, train_dataset = create_dataset(args_opt, data_url, 1, config.batch_size,
usage="eval") usage="eval")
dataset_size = train_dataset.get_dataset_size() dataset_size = train_dataset.get_dataset_size()
callback = LossCallBack(dataset_size) callback = LossCallBack(dataset_size)

View File

@ -155,7 +155,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
# train dataset # train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True, dataset = create_dataset(dataset_path=dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size) repeat_num=1, batch_size=config.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
eval_interval = config.eval_interval eval_interval = config.eval_interval
@ -163,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
# evalutation dataset # evalutation dataset
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False, eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
repeat_num=epoch_size, batch_size=config.eval_batch_size) repeat_num=1, batch_size=config.eval_batch_size)
# loss scale # loss scale
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
@ -260,14 +260,14 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
# train dataset # train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True, dataset = create_dataset(dataset_path=dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=thor_config.batch_size) repeat_num=1, batch_size=thor_config.batch_size)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
eval_interval = thor_config.eval_interval eval_interval = thor_config.eval_interval
# evalutation dataset # evalutation dataset
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False, eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
repeat_num=epoch_size, batch_size=thor_config.eval_batch_size) repeat_num=1, batch_size=thor_config.eval_batch_size)
# loss scale # loss scale
loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)

View File

@ -136,7 +136,7 @@ if __name__ == '__main__':
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
if args_opt.do_train: if args_opt.do_train:
dataset = create_dataset(epoch_size) dataset = create_dataset(1)
batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10) config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)

View File

@ -140,7 +140,7 @@ def train_process(epoch_size, num_classes, batch_size):
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
dataset = create_dataset(epoch_size, training=True, batch_size=batch_size) dataset = create_dataset(1, training=True, batch_size=batch_size)
loss_cb = LossGet() loss_cb = LossGet()
model.train(epoch_size, dataset, callbacks=[loss_cb]) model.train(epoch_size, dataset, callbacks=[loss_cb])

View File

@ -164,7 +164,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
dataset = create_dataset(epoch_size, training=True, dataset = create_dataset(1, training=True,
batch_size=batch_size, rank_id=device_id, rank_size=device_num, batch_size=batch_size, rank_id=device_id, rank_size=device_num,
enable_hccl=enable_hccl) enable_hccl=enable_hccl)

View File

@ -91,8 +91,9 @@ SET(DE_UT_SRCS
cyclic_array_test.cc cyclic_array_test.cc
perf_data_test.cc perf_data_test.cc
c_api_test.cc c_api_test.cc
tensor_op_fusion_pass_test.cc tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc sliding_window_op_test.cc
epoch_ctrl_op_test.cc
) )
add_executable(de_ut_tests ${DE_UT_SRCS}) add_executable(de_ut_tests ${DE_UT_SRCS})

View File

@ -397,23 +397,21 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
std::shared_ptr<CacheMergeOp> myMergeOp; // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build( // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
&myMergeOp); // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
EXPECT_TRUE(rc.IsOk()); // replace it with the required tree structures for cache lookup op and cache merge op.
std::shared_ptr<CacheLookupOp> myLookupOp; std::shared_ptr<CacheOp> myCacheOp;
rc = CacheLookupOp::Builder() rc = CacheOp::Builder()
.SetNumWorkers(3) .SetNumWorkers(4)
.SetOpConnectorSize(3)
.SetClient(myClient) .SetClient(myClient)
.SetSampler(seq_sampler) .SetRowsPerBuffer(3)
.Build(&myLookupOp); .Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<ImageFolderOp> so; std::shared_ptr<ImageFolderOp> so;
ImageFolderOp::Builder builder; ImageFolderOp::Builder builder;
builder.SetSampler(myLookupOp) builder.SetSampler(std::move(seq_sampler))
.SetOpConnectorSize(3) .SetOpConnectorSize(3)
.SetNumWorkers(3) .SetNumWorkers(3)
.SetRowsPerBuffer(2) .SetRowsPerBuffer(2)
@ -432,20 +430,18 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
auto myTree = std::make_shared<ExecutionTree>(); auto myTree = std::make_shared<ExecutionTree>();
rc = myTree->AssociateNode(so); rc = myTree->AssociateNode(so);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myLookupOp);
EXPECT_TRUE(rc.IsOk()); rc = myTree->AssociateNode(myCacheOp);
rc = myTree->AssociateNode(myMergeOp);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp); rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp); rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myMergeOp); rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(myLookupOp); rc = myCacheOp->AddChild(so);
EXPECT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(so);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
rc = myTree->Prepare(); rc = myTree->Prepare();

View File

@ -0,0 +1,639 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include <memory>
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, bool decode = false);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
class MindDataTestEpochCtrlOp : public UT::DatasetOpTesting {
public:
void SetUp() override {
DatasetOpTesting::SetUp();
folder_path = datasets_root_path_ + "/testPK/data";
GlobalInit();
// Start with an empty execution tree
my_tree_ = std::make_shared<ExecutionTree>();
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
rc = my_tree_->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
int32_t i = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
golden_imgs.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
}
std::shared_ptr<ExecutionTree> my_tree_;
Status rc;
std::string golden_imgs;
std::string folder_path;
int32_t label = 0;
std::string result;
int32_t img_class[4] = {0, 1, 2, 3};
};
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_AutoInjectEpoch) {
MS_LOG(WARNING) << "Doing ImageFolder_AutoInjectEpoch.";
int32_t num_epoch = 2 + std::rand() % 5;
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
rc = my_tree_->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch;
std::string golden = golden_imgs;
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
uint64_t i = 0;
for (int epoch = 0; epoch < num_epoch; epoch++) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_TRUE(result == golden);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * num_epoch);
// Try to fetch data beyond the specified number of epochs.
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch) {
MS_LOG(WARNING) << "Doing ImageFolder_Epoch.";
int32_t num_epoch = 2 + std::rand() % 5;
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch;
std::string golden = golden_imgs;
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
uint64_t i = 0;
for (int epoch = 0; epoch < num_epoch; epoch++) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_TRUE(result == golden);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * num_epoch);
// Try to fetch data beyond the specified number of epochs.
rc = di.GetNextAsMap(&tensor_map);
EXPECT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch) {
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch.";
int32_t num_epoch = 2 + std::rand() % 5;
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats;
std::string golden = golden_imgs;
for (int i = 1; i < num_repeats; i++) {
golden += golden_imgs;
}
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
uint64_t i = 0;
for (int epoch = 0; epoch < num_epoch; epoch++) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_TRUE(result == golden);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * num_repeats * num_epoch);
// Try to fetch data beyond the specified number of epochs.
rc = di.GetNextAsMap(&tensor_map);
EXPECT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch) {
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch.";
int32_t num_epoch = 2 + std::rand() % 5;
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
int32_t num_repeats_2 = 3;
std::shared_ptr<RepeatOp> repeat_op_2;
rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2;
std::string golden;
for (int j = 0; j < num_repeats_2; j++) {
for (int i = 0; i < num_repeats; i++) {
golden += golden_imgs;
}
}
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
uint64_t i = 0;
for (int epoch = 0; epoch < num_epoch; epoch++) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_EQ(result.size(), golden.size());
EXPECT_TRUE(result == golden);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_EQ(i, 44 * num_epoch * num_repeats * num_repeats_2);
// Try to fetch data beyond the specified number of epochs.
rc = di.GetNextAsMap(&tensor_map);
EXPECT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf) {
MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf.";
// if num_epoch == -1, it means infinity.
int32_t num_epoch = -1;
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
uint64_t i = 0;
// For this test, we stop at stop_at_epoch number.
int32_t stop_at_epoch = 2 + std::rand() % 6;
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_EQ(result, golden_imgs);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * stop_at_epoch);
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_Inf) {
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf.";
// if num_epoch == -1, it means infinity.
int32_t num_epoch = -1;
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
int32_t num_repeats_2 = 3;
std::shared_ptr<RepeatOp> repeat_op_2;
rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2;
std::string golden;
for (int j = 0; j < num_repeats_2; j++) {
for (int i = 0; i < num_repeats; i++) {
golden += golden_imgs;
}
}
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree_);
TensorMap tensor_map;
uint64_t i = 0;
// For this test, we stop at stop_at_epoch number.
int32_t stop_at_epoch = 2 + std::rand() % 6;
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_EQ(result, golden);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats * num_repeats_2);
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_ChildItr) {
MS_LOG(WARNING) << "Doing ImageFolder_Epoch_ChildItr.";
int32_t num_epoch = 2 + std::rand() % 5;
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "num_epoch: " << num_epoch;
// Start the loop of reading tensors from our pipeline
ChildIterator ci(my_tree_->root().get(), 0, 0);
TensorRow tensor_row;
uint64_t total_sample = 0;
uint64_t i = 0;
uint32_t epoch = 0;
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
while(!ci.eof_handled()) {
i = 0;
while (tensor_row.size() != 0) {
tensor_row[1]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
i++;
}
epoch++;
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
EXPECT_TRUE(result == golden_imgs);
result.clear();
EXPECT_TRUE(i == 44);
total_sample += i;
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
}
EXPECT_TRUE(total_sample == 44 * num_epoch);
// Try to fetch data after last epoch ends.
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(tensor_row.empty());
EXPECT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_ChildItr) {
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_ChildItr.";
int32_t num_epoch = 2 + std::rand() % 5;
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats;
std::string golden;
for (int i = 0; i < num_repeats; i++) {
golden += golden_imgs;
}
// Start the loop of reading tensors from our pipeline
ChildIterator ci(my_tree_->root().get(), 0, 0);
TensorRow tensor_row;
uint64_t total_sample = 0;
uint64_t i = 0;
uint32_t epoch = 0;
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
while(!ci.eof_handled()) {
i = 0;
while (tensor_row.size() != 0) {
tensor_row[1]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
i++;
}
epoch++;
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
EXPECT_TRUE(result == golden);
result.clear();
EXPECT_TRUE(i == 44 * num_repeats);
total_sample += i;
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
}
EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats);
// Try to fetch data after last epoch ends.
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(tensor_row.empty());
EXPECT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_ChildItr) {
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch_ChildItr.";
int32_t num_epoch = 2 + std::rand() % 5;
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
int32_t num_repeats_2 = 3;
std::shared_ptr<RepeatOp> repeat_op_2;
rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2;
std::string golden;
for (int j = 0; j < num_repeats_2; j++) {
for (int i = 0; i < num_repeats; i++) {
golden += golden_imgs;
}
}
// Start the loop of reading tensors from our pipeline
ChildIterator ci(my_tree_->root().get(), 0, 0);
TensorRow tensor_row;
uint64_t total_sample = 0;
uint64_t i = 0;
uint32_t epoch = 0;
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
while(!ci.eof_handled()) {
i = 0;
while (tensor_row.size() != 0) {
tensor_row[1]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
i++;
}
epoch++;
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
EXPECT_TRUE(result == golden);
result.clear();
EXPECT_TRUE(i == 44 * num_repeats * num_repeats_2);
total_sample += i;
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
}
EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats * num_repeats_2);
// Try to fetch data after last epoch ends.
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(tensor_row.empty());
EXPECT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf_ChildItr) {
MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf_ChildItr.";
// if num_epoch == -1, it means infinity.
int32_t num_epoch = -1;
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
ChildIterator ci(my_tree_->root().get(), 0, 0);
TensorRow tensor_row;
uint64_t i = 0;
// For this test, we stop at a random number between 0 - 100 epochs.
int32_t stop_at_epoch = 2 + std::rand() % 5;
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
while (tensor_row.size() != 0) {
tensor_row[1]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_TRUE(result == golden_imgs);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * stop_at_epoch);
}
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_Inf_ChildItr) {
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf_ChildItr.";
// if num_epoch == -1, it means infinity.
int32_t num_epoch = -1;
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op});
rc = my_tree_->Prepare(num_epoch);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
EXPECT_TRUE(rc.IsOk());
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats;
std::string golden;
for (int i = 0; i < num_repeats; i++) {
golden += golden_imgs;
}
// Start the loop of reading tensors from our pipeline
ChildIterator ci(my_tree_->root().get(), 0, 0);
TensorRow tensor_row;
uint64_t i = 0;
// For this test, we stop at a random number between 0 - 100 epochs.
int32_t stop_at_epoch = 2 + std::rand() % 5;
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
while (tensor_row.size() != 0) {
tensor_row[1]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
// Dump all the image into string, to be used as a comparison later.
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
rc = ci.FetchNextTensorRow(&tensor_row);
EXPECT_TRUE(rc.IsOk());
i++;
}
EXPECT_TRUE(result == golden);
result.clear();
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
}
EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats);
}

View File

@ -46,7 +46,8 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_tfreader_op); rc = my_tree->AssociateNode(my_tfreader_op);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
my_tree->AssociateNode(parent_op); rc = my_tree->AssociateNode(parent_op);
ASSERT_TRUE(rc.IsOk());
ASSERT_NE(parent_op, nullptr); ASSERT_NE(parent_op, nullptr);
ASSERT_NE(my_tfreader_op, nullptr); ASSERT_NE(my_tfreader_op, nullptr);
parent_op->AddChild(std::move(my_tfreader_op)); parent_op->AddChild(std::move(my_tfreader_op));

View File

@ -104,9 +104,11 @@ def test_cache_map_basic3():
decode_op = c_vision.Decode() decode_op = c_vision.Decode()
ds1 = ds1.repeat(4) ds1 = ds1.repeat(4)
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
print("ds1.dataset_size is ", ds1.get_dataset_size())
num_iter = 0 num_iter = 0
for _ in ds1.create_dict_iterator(): for _ in ds1.create_dict_iterator():
print("get data from dataset")
num_iter += 1 num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter)) logger.info("Number of data in ds1: {} ".format(num_iter))
@ -152,6 +154,10 @@ def test_cache_map_failure1():
if __name__ == '__main__': if __name__ == '__main__':
test_cache_map_basic1() test_cache_map_basic1()
print("test_cache_map_basic1 success.")
test_cache_map_basic2() test_cache_map_basic2()
print("test_cache_map_basic2 success.")
test_cache_map_basic3() test_cache_map_basic3()
print("test_cache_map_basic3 success.")
test_cache_map_failure1() test_cache_map_failure1()
print("test_cache_map_failure1 success.")

View File

@ -238,7 +238,7 @@ def test_tfrecord_shard_equal_rows():
def test_tfrecord_no_schema_columns_list(): def test_tfrecord_no_schema_columns_list():
logger.info("test_tfrecord_no_schema_columns_list") logger.info("test_tfrecord_no_schema_columns_list")
data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"]) data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
row = data.create_dict_iterator().get_next() row = data.create_dict_iterator().__next__()
assert row["col_sint16"] == [-32768] assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info: with pytest.raises(KeyError) as info:
@ -258,7 +258,7 @@ def test_tfrecord_schema_columns_list():
schema.add_column('col_sint32', de_type=mstype.int64, shape=[1]) schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
schema.add_column('col_sint64', de_type=mstype.int64, shape=[1]) schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"]) data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
row = data.create_dict_iterator().get_next() row = data.create_dict_iterator().__next__()
assert row["col_sint16"] == [-32768] assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info: with pytest.raises(KeyError) as info:

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import time
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger from mindspore import log as logger
@ -35,6 +37,8 @@ def test_case_0():
data = data.device_que() data = data.device_que()
data.send() data.send()
time.sleep(0.1)
data.stop_send()
def test_case_1(): def test_case_1():
@ -58,6 +62,8 @@ def test_case_1():
data = data.device_que() data = data.device_que()
data.send() data.send()
time.sleep(0.1)
data.stop_send()
def test_case_2(): def test_case_2():
@ -84,6 +90,8 @@ def test_case_2():
data = data.device_que() data = data.device_que()
assert data.get_repeat_count() == 2 assert data.get_repeat_count() == 2
data.send() data.send()
time.sleep(0.1)
data.stop_send()
def test_case_3(): def test_case_3():
@ -109,13 +117,17 @@ def test_case_3():
data = data.device_que() data = data.device_que()
data.send() data.send()
time.sleep(0.1)
data.stop_send()
def test_case_tf_file(): def test_case_tf_file():
data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES) data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
data = data.to_device(num_batch=10) data = data.to_device()
data.send() data.send()
time.sleep(0.1)
data.stop_send()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -0,0 +1,608 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Epoch Control op in DE
"""
import itertools
import cv2
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def diff_mse(in1, in2):
"""
diff_mse
"""
mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
return mse * 100
def test_cifar10():
"""
dataset parameter
"""
logger.info("Test dataset parameter")
data_dir_10 = "../data/dataset/testCifar10Data"
num_repeat = 2
batch_size = 32
limit_dataset = 100
# apply dataset operations
data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
data1 = data1.repeat(num_repeat)
data1 = data1.batch(batch_size, True)
num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown.
iter1 = data1.create_tuple_iterator()
epoch_count = 0
sample_count = 0
for _ in range(num_epoch):
row_count = 0
for _ in iter1:
# in this example, each dictionary has keys "image" and "label"
row_count += 1
assert row_count == int(limit_dataset * num_repeat / batch_size)
logger.debug("row_count: ", row_count)
epoch_count += 1
sample_count += row_count
assert epoch_count == num_epoch
logger.debug("total epochs: ", epoch_count)
assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
logger.debug("total sample: ", sample_count)
def test_decode_op():
"""
Test Decode op
"""
logger.info("test_decode_op")
# Decode with rgb format set to True
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
# Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown.
iter1 = data1.create_dict_iterator()
# iter 2 will stop and shutdown pipeline after num_epoch
iter2 = data2.create_dict_iterator(num_epoch)
for _ in range(num_epoch):
i = 0
for item1, item2 in itertools.zip_longest(iter1, iter2):
actual = item1["image"]
expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
assert actual.shape == expected.shape
diff = actual - expected
mse = np.sum(np.power(diff, 2))
assert mse == 0
i = i + 1
assert i == 3
# Users have the option to manually stop the iterator, or rely on garbage collector.
iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
with pytest.raises(RuntimeError) as info:
iter2.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
# Generate 1d int numpy array from 0 - 63
def generator_1d():
"""
generator
"""
for i in range(64):
yield (np.array([i]),)
def test_generator_dict_0():
"""
test generator dict 0
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
i = 0
# create the iterator inside the loop declaration
for item in data1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
def test_generator_dict_1():
"""
test generator dict 1
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
for _ in range(10):
i = 0
# BAD. Do not create iterator every time inside.
# Create iterator outside the epoch for loop.
for item in data1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
def test_generator_dict_2():
"""
test generator dict 2
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_dict_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
# iter1 is still alive and running.
item1 = iter1.__next__()
assert item1
# rely on garbage collector to destroy iter1
def test_generator_dict_3():
"""
test generator dict 3
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_dict_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
# optional
iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_dict_4():
"""
test generator dict 4
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_dict_iterator(num_epochs=10)
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_dict_4_1():
"""
test generator dict 4_1
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
# epoch ctrl op will not be injected if num_epochs is 1.
iter1 = data1.create_dict_iterator(num_epochs=1)
for _ in range(1):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_dict_4_2():
"""
test generator dict 4_2
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
# repeat will not be injected when num repeat is 1.
data1 = data1.repeat(1)
# epoch ctrl op will not be injected if num_epochs is 1.
iter1 = data1.create_dict_iterator(num_epochs=1)
for _ in range(1):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_dict_5():
"""
test generator dict 5
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_dict_iterator(num_epochs=11)
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
# still one more epoch left in the iter1.
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
assert i == 64
# now iter1 has been exhausted, c++ pipeline has been shut down.
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
# Test tuple iterator
def test_generator_tuple_0():
"""
test generator tuple 0
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
i = 0
# create the iterator inside the loop declaration
for item in data1.create_tuple_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
def test_generator_tuple_1():
"""
test generator tuple 1
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
for _ in range(10):
i = 0
# BAD. Do not create iterator every time inside.
# Create iterator outside the epoch for loop.
for item in data1.create_tuple_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64
def test_generator_tuple_2():
"""
test generator tuple 2
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_tuple_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64
# iter1 is still alive and running.
item1 = iter1.__next__()
assert item1
# rely on garbage collector to destroy iter1
def test_generator_tuple_3():
"""
test generator tuple 3
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_tuple_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64
# optional
iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_4():
"""
test generator tuple 4
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_tuple_iterator(num_epochs=10)
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_tuple_5():
"""
test generator tuple 5
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_tuple_iterator(num_epochs=11)
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64
# still one more epoch left in the iter1.
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64
# now iter1 has been exhausted, c++ pipeline has been shut down.
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
# Test with repeat
def test_generator_tuple_repeat_1():
"""
test generator tuple repeat 1
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
iter1 = data1.create_tuple_iterator(num_epochs=11)
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2
# still one more epoch left in the iter1.
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2
# now iter1 has been exhausted, c++ pipeline has been shut down.
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
# Test with repeat
def test_generator_tuple_repeat_repeat_1():
"""
test generator tuple repeat repeat 1
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator(num_epochs=11)
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
# still one more epoch left in the iter1.
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
# now iter1 has been exhausted, c++ pipeline has been shut down.
with pytest.raises(RuntimeError) as info:
iter1.__next__()
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_tuple_repeat_repeat_2():
"""
test generator tuple repeat repeat 2
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
# optional
iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_repeat_repeat_3():
"""
test generator tuple repeat repeat 3
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
for _ in range(5):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
# rely on garbage collector to destroy iter1
def test_generator_reusedataset():
"""
test generator reusedataset
"""
logger.info("Test 1D Generator : 0 - 63")
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
iter1 = data1.create_tuple_iterator()
for _ in range(10):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2
data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator()
for _ in range(5):
i = 0
for item in iter1: # each data is a dictionary
golden = np.array([i % 64])
assert np.array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
data1 = data1.batch(2)
iter1 = data1.create_dict_iterator()
for _ in range(5):
i = 0
sample = 0
for item in iter1: # each data is a dictionary
golden = np.array([[i % 64], [(i + 1) % 64]])
assert np.array_equal(item["data"], golden)
i = i + 2
sample = sample + 1
assert sample == 64 * 3
# rely on garbage collector to destroy iter1

View File

@ -87,7 +87,7 @@ def test_five_crop_error_msg():
data = data.map(input_columns=["image"], operations=transform()) data = data.map(input_columns=["image"], operations=transform())
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
data.create_tuple_iterator().get_next() data.create_tuple_iterator().__next__()
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>" error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
# error msg comes from ToTensor() # error msg comes from ToTensor()

View File

@ -41,18 +41,18 @@ def test_case1():
assert data.get_batch_size() == 2 assert data.get_batch_size() == 2
assert data.get_repeat_count() == 1 assert data.get_repeat_count() == 1
data = data.repeat(10) data = data.repeat(10)
assert data.get_dataset_size() == 6 assert data.get_dataset_size() == 60
assert data.get_batch_size() == 2 assert data.get_batch_size() == 2
assert data.get_repeat_count() == 10 assert data.get_repeat_count() == 10
data = data.project(["new_column"]) data = data.project(["new_column"])
assert data.get_dataset_size() == 6 assert data.get_dataset_size() == 60
assert data.get_batch_size() == 2 assert data.get_batch_size() == 2
assert data.get_repeat_count() == 10 assert data.get_repeat_count() == 10
data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
data1 = data.zip(data2) data1 = data.zip(data2)
assert data1.get_dataset_size() == 6 assert data1.get_dataset_size() == 60
def test_case2(): def test_case2():
@ -65,14 +65,14 @@ def test_case2():
data = data.rename("col_sint64", "new_column") data = data.rename("col_sint64", "new_column")
assert data.get_dataset_size() == 3 assert data.get_dataset_size() == 3
data = data.repeat(10) data = data.repeat(10)
assert data.get_dataset_size() == 3 assert data.get_dataset_size() == 30
data = data.project(["new_column"]) data = data.project(["new_column"])
assert data.get_dataset_size() == 3 assert data.get_dataset_size() == 30
data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10) data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10)
data1 = data.zip(data2) data1 = data.zip(data2)
assert data1.get_dataset_size() == 3 assert data1.get_dataset_size() == 30
def test_case3(): def test_case3():
@ -94,11 +94,11 @@ def test_case4():
data2 = data2.shuffle(100) data2 = data2.shuffle(100)
assert data2.get_dataset_size() == 6 assert data2.get_dataset_size() == 6
data2 = data2.repeat(3) data2 = data2.repeat(3)
assert data2.get_dataset_size() == 6 assert data2.get_dataset_size() == 18
data3 = ds.zip((data1, data2)) data3 = ds.zip((data1, data2))
assert data3.get_dataset_size() == 6 assert data3.get_dataset_size() == 18
def test_case5(): def test_case5():

View File

@ -73,7 +73,7 @@ def test_iterator_weak_ref():
_cleanup() _cleanup()
with pytest.raises(AttributeError) as info: with pytest.raises(AttributeError) as info:
itr2.get_next() itr2.__next__()
assert "object has no attribute 'depipeline'" in str(info.value) assert "object has no attribute 'depipeline'" in str(info.value)
del itr1 del itr1

View File

@ -251,6 +251,49 @@ def test_nested_repeat11():
assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3 assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
def test_repeat_count1():
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1_size = data1.get_dataset_size()
logger.info("dataset size is {}".format(data1_size))
batch_size = 2
repeat_count = 4
resize_height, resize_width = 32, 32
decode_op = vision.Decode()
resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=resize_op)
data1 = data1.repeat(repeat_count)
data1 = data1.batch(batch_size, drop_remainder=False)
dataset_size = data1.get_dataset_size()
logger.info("dataset repeat then batch's size is {}".format(dataset_size))
num1_iter = 0
for _ in data1.create_dict_iterator():
num1_iter += 1
assert data1_size == 3
assert dataset_size == num1_iter == 6
def test_repeat_count2():
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1_size = data1.get_dataset_size()
logger.info("dataset size is {}".format(data1_size))
batch_size = 2
repeat_count = 4
resize_height, resize_width = 32, 32
decode_op = vision.Decode()
resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=resize_op)
data1 = data1.batch(batch_size, drop_remainder=False)
data1 = data1.repeat(repeat_count)
dataset_size = data1.get_dataset_size()
logger.info("dataset batch then repeat's size is {}".format(dataset_size))
num1_iter = 0
for _ in data1.create_dict_iterator():
num1_iter += 1
assert data1_size == 3
assert dataset_size == num1_iter == 8
if __name__ == "__main__": if __name__ == "__main__":
test_tf_repeat_01() test_tf_repeat_01()
@ -268,3 +311,5 @@ if __name__ == "__main__":
test_nested_repeat9() test_nested_repeat9()
test_nested_repeat10() test_nested_repeat10()
test_nested_repeat11() test_nested_repeat11()
test_repeat_count1()
test_repeat_count2()

View File

@ -252,14 +252,14 @@ def test_zip_exception_06():
if __name__ == '__main__': if __name__ == '__main__':
test_zip_01() test_zip_01()
test_zip_02() #test_zip_02()
test_zip_03() #test_zip_03()
test_zip_04() #test_zip_04()
test_zip_05() #test_zip_05()
test_zip_06() #test_zip_06()
test_zip_exception_01() #test_zip_exception_01()
test_zip_exception_02() #test_zip_exception_02()
test_zip_exception_03() #test_zip_exception_03()
test_zip_exception_04() #test_zip_exception_04()
test_zip_exception_05() #test_zip_exception_05()
test_zip_exception_06() #test_zip_exception_06()

2770
tests/ut/python/log Normal file

File diff suppressed because it is too large Load Diff

View File

@ -274,6 +274,9 @@ class DatasetLenet():
def get_repeat_count(self): def get_repeat_count(self):
return 1 return 1
def create_tuple_iterator(self):
return self
def test_train_32k_8p(batch_size=32, num_classes=32768): def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8 dev_num = 8

View File

@ -61,6 +61,9 @@ class DatasetLenet():
def get_repeat_count(self): def get_repeat_count(self):
return 1 return 1
def create_tuple_iterator(self):
return self
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):

View File

@ -58,6 +58,9 @@ class Dataset():
def get_repeat_count(self): def get_repeat_count(self):
return 1 return 1
def create_tuple_iterator(self):
return self
class GatherV2(_Loss): class GatherV2(_Loss):
def __init__(self, index_dim, strategy, index_size=16): def __init__(self, index_dim, strategy, index_size=16):

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset helper."""
import pytest
import numpy as np
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.train.dataset_helper import DatasetHelper
from ....dataset_mock import MindData
def get_dataset(batch_size=1):
dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
(batch_size, 20), (batch_size, 20), (batch_size, 20))
dataset = MindData(size=2, batch_size=batch_size, np_types=dataset_types,
output_shapes=dataset_shapes, input_indexs=(0, 1))
return dataset
def test_dataset_helper_dataset_sink_mode_str():
dataset = get_dataset(32)
with pytest.raises(TypeError):
DatasetHelper(dataset, dataset_sink_mode="True")
def test_dataset_helper_dataset_sink_mode_int():
dataset = get_dataset(32)
with pytest.raises(TypeError):
DatasetHelper(dataset, dataset_sink_mode=1)
def test_dataset_helper_sink_size_bool():
dataset = get_dataset(32)
with pytest.raises(TypeError):
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=True)
def test_dataset_helper_sink_size_float():
dataset = get_dataset(32)
with pytest.raises(TypeError):
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=1.0)
def test_dataset_helper_sink_size_negative():
dataset = get_dataset(32)
with pytest.raises(ValueError):
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=-2)
def test_dataset_iter_normal():
dataset = get_dataset(32)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=False)
count = 0
for _ in range(2):
for _ in dataset_helper:
count += 1
dataset.reset()
assert count == 6
@pytest.mark.skipif('not context.get_context("enable_ge")')
def test_dataset_iter_ge():
init()
dataset = get_dataset(32)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
count = 0
for _ in range(2):
for _ in dataset_helper:
count += 1
assert count == 2
@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms_loop_sink():
init()
context.set_context(enable_loop_sink=True)
dataset = get_dataset(32)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
count = 0
for _ in range(2):
for inputs in dataset_helper:
count += 1
assert inputs == tuple()
assert count == 2
@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms():
init()
context.set_context(enable_loop_sink=False)
dataset = get_dataset(32)
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)