!3212 GetDatasize feature
Merge pull request !3212 from anzhengqi/epochs-ready
This commit is contained in:
commit
8e4c0a9d93
|
@ -25,6 +25,8 @@
|
|||
#include "minddata/dataset/engine/dataset_iterator.h"
|
||||
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/filter_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
|
@ -84,7 +86,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
|
|||
{kRandomData, &DEPipeline::ParseRandomDataOp},
|
||||
{kTextFile, &DEPipeline::ParseTextFileOp},
|
||||
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
|
||||
{kClue, &DEPipeline::ParseClueOp}};
|
||||
{kClue, &DEPipeline::ParseClueOp},
|
||||
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}};
|
||||
|
||||
DEPipeline::DEPipeline() : iterator_(nullptr) {
|
||||
try {
|
||||
|
@ -166,8 +169,8 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
|
|||
Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
|
||||
|
||||
// Function to launch the tree execution.
|
||||
Status DEPipeline::LaunchTreeExec() {
|
||||
RETURN_IF_NOT_OK(tree_->Prepare());
|
||||
Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
|
||||
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
|
||||
RETURN_IF_NOT_OK(tree_->Launch());
|
||||
iterator_ = std::make_unique<DatasetIterator>(tree_);
|
||||
if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
|
||||
|
@ -252,6 +255,16 @@ int DEPipeline::GetRepeatCount() const { return repeat_num_; }
|
|||
|
||||
float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
|
||||
|
||||
Status DEPipeline::StopSend() {
|
||||
// tree_.root() must be DeviceQueueOp
|
||||
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
|
||||
if (op == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp");
|
||||
}
|
||||
op->StopSend();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
|
||||
|
||||
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
|
||||
|
@ -804,6 +817,18 @@ Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
if (args["count"].is_none()) {
|
||||
std::string err_msg = "Error: count is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
std::shared_ptr<EpochCtrlOp> op;
|
||||
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op));
|
||||
*top = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
|
||||
|
@ -973,8 +998,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data
|
|||
(void)builder->SetDeviceType(ToString(value));
|
||||
} else if (key == "device_id") {
|
||||
(void)builder->SetDeviceId(ToInt(value));
|
||||
} else if (key == "num_batch") {
|
||||
(void)builder->SetNumBatch(ToInt(value));
|
||||
} else if (key == "send_epoch_end") {
|
||||
(void)builder->SetSendEpochEnd(ToBool(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,7 +70,8 @@ enum OpName {
|
|||
kRandomData,
|
||||
kTextFile,
|
||||
kBuildVocab,
|
||||
kClue
|
||||
kClue,
|
||||
kEpochCtrl
|
||||
};
|
||||
|
||||
// The C++ binder class that we expose to the python script.
|
||||
|
@ -90,7 +91,7 @@ class DEPipeline {
|
|||
Status AssignRootNode(const DsOpPtr &dataset_op);
|
||||
|
||||
// Function to launch the tree execution.
|
||||
Status LaunchTreeExec();
|
||||
Status LaunchTreeExec(int32_t num_epochs);
|
||||
|
||||
// Get a row of data as dictionary of column name to the value.
|
||||
Status GetNextAsMap(py::dict *output);
|
||||
|
@ -143,6 +144,10 @@ class DEPipeline {
|
|||
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
@ -189,6 +194,8 @@ class DEPipeline {
|
|||
|
||||
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status StopSend();
|
||||
|
||||
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
private:
|
||||
|
|
|
@ -159,7 +159,7 @@ void bindDEPipeline(py::module *m) {
|
|||
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
|
||||
.def("SetBatchParameters",
|
||||
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
|
||||
.def("GetNextAsMap",
|
||||
[](DEPipeline &de) {
|
||||
py::dict out;
|
||||
|
@ -188,6 +188,7 @@ void bindDEPipeline(py::module *m) {
|
|||
.def("GetBatchSize", &DEPipeline::GetBatchSize)
|
||||
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
||||
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
|
||||
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
||||
return true;
|
||||
|
@ -999,7 +1000,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
.value("BUILDVOCAB", OpName::kBuildVocab)
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile)
|
||||
.value("CLUE", OpName::kClue);
|
||||
.value("CLUE", OpName::kClue)
|
||||
.value("EPOCHCTRL", OpName::kEpochCtrl);
|
||||
|
||||
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
|
||||
.value("DE_JIEBA_MIX", JiebaMode::kMix)
|
||||
|
|
|
@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
|
|||
out_map->clear();
|
||||
|
||||
TensorRow curr_row;
|
||||
MS_LOG(INFO) << "get next as map start.";
|
||||
RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
|
||||
MS_LOG(INFO) << "fetchNextTensor success.";
|
||||
|
||||
// Return empty map if there's no data
|
||||
if (curr_row.empty()) {
|
||||
|
@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
|
|||
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
|
||||
// want to iterate again.
|
||||
if (eof_handled_) {
|
||||
return Status::OK();
|
||||
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
|
||||
// Check if we need to get a new DataBuffer to iterate.
|
||||
|
@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
|
|||
// Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
|
||||
// handle eoe and eof messages here.
|
||||
//
|
||||
// An eoe buffer means we have iterated fully to the end of the tree.
|
||||
// An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
|
||||
// all operators.
|
||||
// An eoe buffer means we have iterated an epoch.
|
||||
// The next buffer in the pipeline might be an EOF or a databuffer for next epoch
|
||||
if (curr_buffer_->eoe()) {
|
||||
MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
|
||||
|
||||
// Before returning the last empty vector, fetch the eof buffer which should be the last
|
||||
// buffer, and then free it.
|
||||
RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
|
||||
|
||||
if (!curr_buffer_->eof()) {
|
||||
RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
|
||||
}
|
||||
eof_handled_ = true;
|
||||
curr_buffer_.reset(); // explicitly free the eof buffer
|
||||
// Set tree to Finished state
|
||||
root_->Tree()->SetFinished();
|
||||
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
curr_buffer_.reset(); // explicitly free the eoe buffer
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// An eof buffer means it is the end of execution and all operators are shutting down.
|
||||
// Because there is no more data to return to the caller, this will change `eof_handled_` state and
|
||||
// returns status unexpected error.
|
||||
if (curr_buffer_->eof()) {
|
||||
// An eof by itself, without being preceded by an eoe, is possible if a repeat operator
|
||||
// exists below us in the stack. Repeat operator eats eoe's but eventually allows the
|
||||
// flow of an eof up the pipeline by itself.
|
||||
eof_handled_ = true;
|
||||
curr_buffer_.reset(); // explicitly free the eof buffer
|
||||
// Set tree to Finished state
|
||||
root_->Tree()->SetFinished();
|
||||
return Status::OK();
|
||||
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
|
|||
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
|
||||
// want to iterate again.
|
||||
if (eof_handled_) {
|
||||
return Status::OK();
|
||||
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
|
||||
// Check if we need to get a new DataBuffer to iterate.
|
||||
if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
|
||||
// GetNextInput() depends on current_op's EoeReceived. So, EOE buffer might be already be handled and
|
||||
// this child iterator might not see EOE buffer.
|
||||
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
|
||||
|
||||
// Unlike the DatasetIterator, this child iterator does not quit after eoe.
|
||||
// Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
|
||||
// If an eoe is picked up here, we simply return an empty vector and it's up to the
|
||||
// caller to decide what it wants to do next.
|
||||
if (curr_buffer_->eoe()) {
|
||||
MS_LOG(DEBUG) << "Child iterator picked up EOE.";
|
||||
end_epoch_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
end_epoch_ = false;
|
||||
}
|
||||
|
||||
if (curr_buffer_->eof()) {
|
||||
|
|
|
@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase {
|
|||
// @return The string to column id mapping.
|
||||
std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
|
||||
|
||||
// Return T/F if end of epoch
|
||||
bool end_of_epoch() { return end_epoch_; }
|
||||
|
||||
private:
|
||||
DatasetOp *current_op_; // The parent operator. We consume from it's children.
|
||||
int32_t child_idx_; // The specific child this iterator will fetch from.
|
||||
|
|
|
@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
|
|||
shuffle_op.cc
|
||||
zip_op.cc
|
||||
concat_op.cc
|
||||
epoch_ctrl_op.cc
|
||||
cache_base_op.cc
|
||||
cache_lookup_op.cc
|
||||
cache_op.cc
|
||||
|
|
|
@ -17,11 +17,13 @@
|
|||
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -202,5 +204,29 @@ BuildVocabOp::Builder::Builder()
|
|||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void BuildVocabOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Always show the id and name as first line regardless if this summary or detailed print
|
||||
out << "(" << std::setw(2) << operator_id_ << ") <BuildVocabOp>:";
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCode is needed here to show more info about the op."
|
||||
<< "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-Visitor accept method for NodePass
|
||||
Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call the pre-visitation
|
||||
return p->PreRunOnNode(shared_from_base<BuildVocabOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -131,6 +131,21 @@ class BuildVocabOp : public ParallelOp {
|
|||
|
||||
~BuildVocabOp() = default;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
/// \param[out] out The output stream to write output to
|
||||
/// \param[in] show_all A bool to control if you want to show all info or just a summary
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \briefStream output operator overload
|
||||
/// \notes This allows you to write the debug print info using stream operators
|
||||
/// \param[out] out Reference to the output stream being overloaded
|
||||
/// \param[in] vop - reference to the BuildVocabOp to display
|
||||
/// \return - the output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) {
|
||||
vop.Print(out, false);
|
||||
return out;
|
||||
}
|
||||
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// collect the work product from each worker
|
||||
|
@ -152,6 +167,12 @@ class BuildVocabOp : public ParallelOp {
|
|||
|
||||
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
|
||||
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
const int32_t interval_;
|
||||
bool special_first_;
|
||||
|
|
|
@ -96,7 +96,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
|
|||
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
|
||||
RETURN_IF_NOT_OK(EofReceived(worker_id));
|
||||
return Status::OK();
|
||||
}
|
||||
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
|
||||
|
@ -298,5 +298,19 @@ Status CacheMergeOp::EoeReceived(int32_t worker_id) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Base-class override for handling cases when an eof is received.
|
||||
Status CacheMergeOp::EofReceived(int32_t worker_id) {
|
||||
// If we are not in a repeated path, then the merge op gets a eof by itself, without first
|
||||
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
|
||||
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
|
||||
// provides that for us.
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
|
||||
MS_LOG(DEBUG) << "Cache merge sending eoe";
|
||||
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
|
||||
}
|
||||
MS_LOG(DEBUG) << "Cache merge sending eof";
|
||||
return DatasetOp::EofReceived(worker_id);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -176,6 +176,11 @@ class CacheMergeOp : public ParallelOp {
|
|||
/// \return Status object
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
/// \brief Base-class override for handling cases when an eof is received.
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
Status EofReceived(int32_t worker_id) override;
|
||||
|
||||
protected:
|
||||
Status ComputeColMap() override;
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
@ -102,6 +103,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Removes child operator in this operator.
|
||||
Status DatasetOp::RemoveChildren() {
|
||||
for (const auto &child : child_) {
|
||||
child->RemoveParent(this);
|
||||
}
|
||||
child_.clear();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds a parent operator to this operator
|
||||
void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
|
||||
|
@ -185,6 +195,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
|
|||
}
|
||||
}
|
||||
|
||||
// Getter function to get all of our children.
|
||||
std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
|
||||
|
||||
// Getter function to get all of our parents.
|
||||
std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
|
||||
|
||||
// Creates the connector within this operator
|
||||
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
|
||||
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
|
||||
|
|
|
@ -76,6 +76,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return Status eerror code returned
|
||||
Status Remove();
|
||||
|
||||
// Removes child operator in this operator.
|
||||
Status RemoveChildren();
|
||||
|
||||
/// \brief Getter function to get a shared pointer to our child
|
||||
/// \param[in] child_index An operator can have n children. Indicates which child to return.
|
||||
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
|
||||
|
@ -86,6 +89,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
|
||||
void Parent(DatasetOp **parent, int32_t parent_index) const;
|
||||
|
||||
// Getter function to get all of our children.
|
||||
std::vector<std::shared_ptr<DatasetOp>> children() const;
|
||||
|
||||
// Getter function to get all of our parents.
|
||||
std::vector<DatasetOp *> parents() const;
|
||||
|
||||
// Inserts a operator as the parent current op.
|
||||
// Inserted op will become the sole parent of the current op.
|
||||
// The existing parent of the current op will be transferred to the inserted op.
|
||||
|
|
|
@ -25,19 +25,21 @@
|
|||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/perf/profiling.h"
|
||||
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
|
||||
int32_t op_connector_size, int64_t num_batch)
|
||||
int32_t op_connector_size, bool send_epoch_end)
|
||||
: PipelineOp(op_connector_size),
|
||||
channel_name_(channel_name),
|
||||
device_type_(device_type),
|
||||
device_id_(device_id),
|
||||
prefetch_size_(prefetch_size),
|
||||
num_batch_(num_batch) {}
|
||||
send_epoch_end_(send_epoch_end),
|
||||
stop_send_(false) {}
|
||||
|
||||
DeviceQueueOp::~DeviceQueueOp() {}
|
||||
|
||||
|
@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size)
|
|||
: builder_prefetch_size_(prefetch_size),
|
||||
builder_device_id_(0),
|
||||
builder_device_type_(DeviceType::CPU),
|
||||
builder_channel_name_(""),
|
||||
builder_num_batch_(0) {
|
||||
builder_channel_name_("") {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
@ -64,6 +65,18 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
|
||||
// this method checks if the buffer meets the conditions to be sent to TDT
|
||||
if (buffer->NumRows() != 0) {
|
||||
TensorRow row;
|
||||
buffer->GetRow(0, &row);
|
||||
for (const auto &item : row) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeviceQueueOp::operator()() {
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
|
@ -82,23 +95,10 @@ Status DeviceQueueOp::operator()() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
|
||||
// this method checks if the buffer meets the conditions to be sent to TDT
|
||||
if (buffer->NumRows() != 0) {
|
||||
TensorRow row;
|
||||
buffer->GetRow(0, &row);
|
||||
for (const auto &item : row) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
Status DeviceQueueOp::SendDataToAscend() {
|
||||
MS_LOG(INFO) << "Device queue, sending data to Ascend.";
|
||||
int64_t total_batch = 0;
|
||||
bool is_break_loop = false;
|
||||
double batch_start_time, end_time;
|
||||
int32_t batch_cost, tdt_cost;
|
||||
int32_t connector_size = 0;
|
||||
|
@ -115,15 +115,20 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
std::unique_ptr<DataBuffer> current_buffer;
|
||||
RETURN_IF_NOT_OK(GetNextInput(¤t_buffer));
|
||||
|
||||
while (!current_buffer->eof() && !is_break_loop) {
|
||||
while (!current_buffer->eoe() && !is_break_loop) {
|
||||
while (!current_buffer->eof()) {
|
||||
while (!current_buffer->eoe()) {
|
||||
RETURN_IF_NOT_OK(CheckExceptions(current_buffer));
|
||||
TensorRow currRow;
|
||||
for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) {
|
||||
for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
|
||||
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
|
||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
|
||||
if (status == TdtStatus::FAILED) {
|
||||
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
|
||||
if (stop_send_) {
|
||||
MS_LOG(INFO) << "stop_send received";
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
|
||||
}
|
||||
}
|
||||
|
||||
if (isProfilingEnable) {
|
||||
|
@ -140,9 +145,6 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size);
|
||||
}
|
||||
total_batch++;
|
||||
if (num_batch_ > 0 && total_batch == num_batch_) {
|
||||
is_break_loop = true;
|
||||
}
|
||||
}
|
||||
if (isProfilingEnable) {
|
||||
connector_size = ChildOpConnectorSize();
|
||||
|
@ -150,6 +152,19 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
}
|
||||
RETURN_IF_NOT_OK(GetNextInput(¤t_buffer));
|
||||
}
|
||||
if (current_buffer->eoe() && send_epoch_end_) {
|
||||
TensorRow currRow;
|
||||
auto status =
|
||||
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
|
||||
if (status == TdtStatus::FAILED) {
|
||||
if (stop_send_) {
|
||||
MS_LOG(INFO) << "stop_send received";
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isProfilingEnable) {
|
||||
connector_size = ChildOpConnectorSize();
|
||||
connector_capacity = ChildOpConnectorCapacity();
|
||||
|
@ -158,7 +173,7 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
}
|
||||
|
||||
tree_->SetFinished();
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -196,9 +211,6 @@ Status DeviceQueueOp::SendDataToGPU() {
|
|||
}
|
||||
RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
|
||||
total_batch++;
|
||||
if (num_batch_ > 0 && total_batch == num_batch_) {
|
||||
is_break_loop = true;
|
||||
}
|
||||
}
|
||||
if (!TaskManager::FindMe()->Interrupted())
|
||||
RETURN_IF_NOT_OK(GetNextInput(¤t_buffer));
|
||||
|
@ -211,12 +223,10 @@ Status DeviceQueueOp::SendDataToGPU() {
|
|||
is_break_loop = true;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
|
||||
|
||||
GpuBufferMgr::GetInstance().Close(handle);
|
||||
|
||||
GpuBufferMgr::GetInstance().CloseConfirm();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -240,8 +250,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
|
|||
if (ret == BlockQueueStatus_T::ERROR_INPUT) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Retry pushing data...";
|
||||
continue;
|
||||
if (!stop_send_) {
|
||||
MS_LOG(WARNING) << "Retry pushing data...";
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
|
@ -283,13 +296,11 @@ Status DeviceQueueOp::SendDataToCPU() {
|
|||
MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << ".";
|
||||
MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << ".";
|
||||
total_batch++;
|
||||
if (num_batch_ > 0 && total_batch == num_batch_) {
|
||||
break;
|
||||
}
|
||||
if (stop_send_) break;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
|
@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetNumBatch(int64_t num_batch) {
|
||||
builder_num_batch_ = num_batch;
|
||||
Builder &SetSendEpochEnd(bool send_epoch_end) {
|
||||
builder_send_epoch_end_ = send_epoch_end;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp {
|
|||
// to call this Build() method. It will instantiate the DeviceQueueOp
|
||||
// and return it to caller as a shared pointer.
|
||||
Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
|
||||
*ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
|
||||
builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_);
|
||||
*ptr =
|
||||
std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
|
||||
builder_prefetch_size_, builder_op_connector_size_, builder_send_epoch_end_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp {
|
|||
int32_t builder_device_id_;
|
||||
DeviceType builder_device_type_;
|
||||
std::string builder_channel_name_;
|
||||
int64_t builder_num_batch_;
|
||||
int32_t builder_op_connector_size_;
|
||||
bool builder_send_epoch_end_;
|
||||
};
|
||||
|
||||
// Name: constructor
|
||||
// Description
|
||||
DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
|
||||
int32_t op_connector_size, int64_t num_batch);
|
||||
int32_t op_connector_size, bool send_epoch_end);
|
||||
|
||||
// Name: destructor
|
||||
// Description
|
||||
|
@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp {
|
|||
|
||||
const int32_t get_prefetch_size() { return prefetch_size_; }
|
||||
|
||||
void StopSend() { stop_send_ = true; }
|
||||
|
||||
// Name: Print()
|
||||
// Description: A function that prints info about the node
|
||||
void Print(std::ostream &out, // In: The output stream to print to
|
||||
|
@ -149,6 +153,7 @@ class DeviceQueueOp : public PipelineOp {
|
|||
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
|
||||
Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const;
|
||||
|
||||
private:
|
||||
#ifdef ENABLE_TDTQUE
|
||||
Status SendDataToAscend();
|
||||
#endif
|
||||
|
@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp {
|
|||
DeviceType device_type_;
|
||||
const int32_t device_id_;
|
||||
const int32_t prefetch_size_;
|
||||
const int64_t num_batch_;
|
||||
const bool send_epoch_end_;
|
||||
bool stop_send_;
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
std::shared_ptr<TdtPlugin> tdtInstancePtr;
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor
|
||||
EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) : RepeatOp(num_epoch) { MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; }
|
||||
|
||||
// Destructor
|
||||
EpochCtrlOp::~EpochCtrlOp() {}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Always show the id and name as first line regardless if this summary or detailed print
|
||||
out << "(" << std::setw(2) << operator_id_ << ") <EpochCtrlOp>:";
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << " [epochs: " << max_repeats_ << "]\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
|
||||
<< "\nLeaf Nodes in execution path:";
|
||||
if (!eoe_ops_.empty()) {
|
||||
for (size_t i = 0; i < eoe_ops_.size(); i++) {
|
||||
out << "\n Operator: " << eoe_ops_[i]->id();
|
||||
}
|
||||
} else {
|
||||
out << " None.";
|
||||
}
|
||||
out << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
|
||||
if (child_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node.");
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> buf;
|
||||
|
||||
// `retry_if_eoe` is false because EpochCtrlOp does not eat EOE.
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, false));
|
||||
|
||||
// Only intercept EOE for EoeReceived processing, after that the EOE is forwarded to next op.
|
||||
// Other databuffers containing data or EOF will simply be forwarded.
|
||||
// EOF can simply be forwarded because this op does not spawn any thread, thus does not require clean up.
|
||||
if (buf->eoe()) {
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
}
|
||||
|
||||
*p_buffer = std::move(buf);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
|
||||
repeat_count_++;
|
||||
MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
|
||||
<< ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;
|
||||
|
||||
// If we've reached the requested epoch count, then flag the leaf nodes
|
||||
// to tell them they've got one more epoch to perform. When they reach the end
|
||||
// of the last epoch, they quit rather than loop again.
|
||||
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
|
||||
eoe_op->set_control_flag(kDeOpLastRepeat);
|
||||
}
|
||||
}
|
||||
|
||||
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
|
||||
state_ = OpState::kDeOpIdle;
|
||||
|
||||
if (repeat_count_ != max_repeats_) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pre-Visitor accept method for NodePass
|
||||
Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call the pre-visitation
|
||||
return p->PreRunOnNode(shared_from_base<EpochCtrlOp>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call the pre-visitation
|
||||
return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class EpochCtrlOp : public RepeatOp {
|
||||
public:
|
||||
class Builder : public RepeatOp::Builder {
|
||||
public:
|
||||
// Builder constructor. Creates the builder object.
|
||||
// @note No default args
|
||||
// @param count - The number of repeats to do
|
||||
// @return This is a constructor.
|
||||
explicit Builder(int32_t count) : RepeatOp::Builder(count) {}
|
||||
|
||||
// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new EpochCtrlOp object
|
||||
Status Build(std::shared_ptr<EpochCtrlOp> *);
|
||||
};
|
||||
|
||||
// Contructor
|
||||
explicit EpochCtrlOp(int32_t num_epoch);
|
||||
|
||||
// Destructor
|
||||
~EpochCtrlOp();
|
||||
|
||||
// A print method typically used for debugging
|
||||
// @param out - The output stream to write output to
|
||||
// @param show_all - A bool to control if you want to show all info or just a summary
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
||||
// Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us
|
||||
// will simply bounce you to get a buffer from our child.
|
||||
// Epoch Control Op does not eat the EOE, it will pass the EOE to the next op.
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
// @param worker_id - The worker id
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
|
|
@ -132,6 +132,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
|
|||
|
||||
// Invoke a reset against the eoe nodes only.
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
}
|
||||
|
||||
|
@ -167,8 +168,9 @@ int32_t RepeatOp::num_consumers() const {
|
|||
Status RepeatOp::Reset() {
|
||||
// If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
|
||||
// In that case, we now have to bounce the reset down to our own eoe ops.
|
||||
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
|
||||
MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
}
|
||||
state_ = OpState::kDeOpRunning;
|
||||
|
|
|
@ -46,7 +46,7 @@ class RepeatOp : public PipelineOp {
|
|||
// @return shared_ptr to the new RepeatOp object
|
||||
Status Build(std::shared_ptr<RepeatOp> *);
|
||||
|
||||
private:
|
||||
protected:
|
||||
int32_t build_max_repeats_;
|
||||
|
||||
Status SanityCheck() const;
|
||||
|
@ -131,11 +131,11 @@ class RepeatOp : public PipelineOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return "RepeatOp"; }
|
||||
|
||||
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||
/// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||
// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
|
||||
|
||||
private:
|
||||
protected:
|
||||
int32_t max_repeats_; // The number of repeats that the user requested
|
||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
|
||||
|
|
|
@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) {
|
|||
if (eof_) {
|
||||
return Status::OK();
|
||||
}
|
||||
// One of our child iterators encounter EOE. Returns and proceed with draining phase.
|
||||
if (new_row.empty()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pack this first row into our tensor table
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
|
||||
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
|
||||
#include "minddata/dataset/engine/perf/profiling.h"
|
||||
#include "minddata/dataset/engine/perf/monitor.h"
|
||||
|
@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
|
|||
if (op->tree_ == this) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
|
||||
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) {
|
||||
std::string err_msg =
|
||||
"Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
|
||||
" Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " +
|
||||
std::to_string(static_cast<int>(kDeTStateBuilding));
|
||||
std::to_string(static_cast<int>(kDeTStateBuilding)) + " or " + std::to_string(static_cast<int>(kDeTStatePrepare));
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
|
@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
|||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status - The error code return
|
||||
Status ExecutionTree::Prepare() {
|
||||
Status ExecutionTree::Prepare(int32_t num_epochs) {
|
||||
num_epochs_ = num_epochs;
|
||||
|
||||
// Pre optimization compulsory transformation
|
||||
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
|
||||
|
||||
|
@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
|
|||
std::vector<std::unique_ptr<Pass>> pre_actions;
|
||||
// Construct pre actions
|
||||
MS_LOG(INFO) << "Running pre pass loops.";
|
||||
pre_actions.push_back(std::make_unique<InjectionPass>());
|
||||
pre_actions.push_back(std::make_unique<RemovalPass>());
|
||||
pre_actions.push_back(std::make_unique<CacheTransformPass>());
|
||||
// Apply pre action passes
|
||||
|
@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() {
|
|||
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (root_ == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");
|
||||
}
|
||||
|
||||
// Start the recursive prepare
|
||||
RETURN_IF_NOT_OK(this->PrepareNode(root_));
|
||||
tree_state_ = kDeTStateReady;
|
||||
|
|
|
@ -176,7 +176,7 @@ class ExecutionTree {
|
|||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status - The error code return
|
||||
Status Prepare();
|
||||
Status Prepare(int num_epochs = -1);
|
||||
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// @return Status - The error code return
|
||||
|
@ -193,6 +193,7 @@ class ExecutionTree {
|
|||
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
||||
// it ready for execution.
|
||||
// @param Total number of epochs that will be run on this tree
|
||||
// @return Status - The error code return
|
||||
Status PrepareDeprecated();
|
||||
|
||||
|
@ -231,6 +232,10 @@ class ExecutionTree {
|
|||
// Optional optimizations status
|
||||
bool OptimizationEnabled() const { return optimize_; }
|
||||
|
||||
// Getter function to get the total number of epochs to be run on this tree.
|
||||
// @return total number of epochs
|
||||
int32_t num_epochs() { return num_epochs_; }
|
||||
|
||||
private:
|
||||
// A helper functions for doing the recursive printing
|
||||
// @param dataset_op - The dataset op to print
|
||||
|
@ -245,6 +250,7 @@ class ExecutionTree {
|
|||
int32_t id_count_; // Counter for generating operator id's
|
||||
uint32_t prepare_flags_; // Flags used during tree prepare
|
||||
TreeState tree_state_; // Tracking the current tree state
|
||||
int32_t num_epochs_; // Total number of epochs to run for this tree
|
||||
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
|
||||
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
|
||||
bool optimize_; // Flag to enable optional optimizations
|
||||
|
|
|
@ -5,6 +5,7 @@ add_library(engine-opt OBJECT
|
|||
post/repeat_pass.cc
|
||||
pre/cache_pass.cc
|
||||
pre/cache_transform_pass.cc
|
||||
pre/injection_pass.cc
|
||||
pre/removal_nodes.cc
|
||||
pre/removal_pass.cc
|
||||
optional/tensor_op_fusion_pass.cc
|
||||
|
|
|
@ -16,11 +16,13 @@
|
|||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/rename_op.h"
|
||||
|
@ -230,6 +232,11 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
|
|||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
|
@ -244,5 +251,15 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified
|
|||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -77,6 +77,10 @@ class CacheMergeOp;
|
|||
|
||||
class CacheLookupOp;
|
||||
|
||||
class EpochCtrlOp;
|
||||
|
||||
class BuildVocabOp;
|
||||
|
||||
// The base class Pass is the basic unit of tree transformation.
|
||||
// The actual implementation of the passes will be derived from here.
|
||||
class Pass : public std::enable_shared_from_this<Pass> {
|
||||
|
@ -190,12 +194,18 @@ class NodePass : public Pass {
|
|||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
|
||||
|
||||
private:
|
||||
// Helper function to perform DFS visit
|
||||
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -28,6 +29,9 @@ RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(fa
|
|||
|
||||
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Create a new stack for eoe operators and push onto our stack of stacks.
|
||||
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
|
||||
eoe_op_stacks_.push(std::move(new_stack));
|
||||
// If we are already repeated, then this is a nested repeat.
|
||||
if (is_repeated_) {
|
||||
nested_repeats_++;
|
||||
|
@ -36,6 +40,18 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified)
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// EpochCtrl is derived from RepeatOp. Generally it should do the identical setup
|
||||
// that RepeatOp does. However, epoch control is actually simpler because it can
|
||||
// only exist as the root node so it doesn't need all the nested code.
|
||||
// Create a new stack for eoe operators and push onto our stack of stacks.
|
||||
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
|
||||
eoe_op_stacks_.push(std::move(new_stack));
|
||||
is_repeated_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Identifies the subtree below this node as being in a cache merge path
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Turn on the flag that we're under a merge op
|
||||
|
@ -47,13 +63,24 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
|
|||
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
|
||||
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||
|
||||
while (leaf_op != nullptr) {
|
||||
node->AddToEoeList(leaf_op);
|
||||
leaf_op = PopFromEOEOpStack();
|
||||
}
|
||||
|
||||
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
|
||||
// at this time, so we can pop it to get rid of it.
|
||||
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
if (!current_stack->empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
|
||||
}
|
||||
eoe_op_stacks_.pop();
|
||||
|
||||
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
|
||||
// and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
|
||||
// and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed
|
||||
// from the save area, because the merge op above us may also take action on it later for a different
|
||||
// case when there is no repeat in the merge leg.
|
||||
if (is_merge_ && cache_lookup_) {
|
||||
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
node->AddToEoeList(std::move(cache_lookup_));
|
||||
|
@ -65,16 +92,29 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
AddToEOEOpStack(node);
|
||||
nested_repeats_--;
|
||||
}
|
||||
|
||||
// If we are not nested, or we were the top-most repeat, now we clear the flag
|
||||
if (nested_repeats_ == 0) {
|
||||
} else {
|
||||
// If we are not nested, or we were the top-most repeat, now we clear the flag
|
||||
if (nested_repeats_ != 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
|
||||
}
|
||||
is_repeated_ = false;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Hooks up any identified eoe nodes under this repeat.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// Pop the leaf ops from the save-area stack and add them to the eoe node tracking
|
||||
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||
while (leaf_op != nullptr) {
|
||||
node->AddToEoeList(leaf_op);
|
||||
leaf_op = PopFromEOEOpStack();
|
||||
}
|
||||
is_repeated_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// CacheOp removes previous leaf ops and replaces them with itself
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
if (is_repeated_) {
|
||||
|
@ -118,9 +158,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
|||
// Turns off the tracking for operations under merge op
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Setting the flag is needed since we didn't call the base class DatasetOp version
|
||||
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
if (is_repeated_) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
|
||||
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
|
||||
if (cache_lookup_) {
|
||||
AddToEOEOpStack(std::move(cache_lookup_));
|
||||
}
|
||||
}
|
||||
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
|
||||
is_merge_ = false;
|
||||
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -135,25 +182,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
|
|||
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
|
||||
if (is_repeated_) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
AddToEOEOpStack(node);
|
||||
} else {
|
||||
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
|
||||
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
|
||||
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
|
||||
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
|
||||
// Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that.
|
||||
}
|
||||
|
||||
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
|
||||
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
|
||||
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
|
||||
// Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
|
||||
// add the lookup to the eoe stack
|
||||
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds an operator to the eoe operator stack save area
|
||||
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
|
||||
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
|
||||
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
current_stack->push(dataset_op);
|
||||
}
|
||||
|
||||
// Pops an operator from the eoe operator stack save area
|
||||
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
|
||||
std::shared_ptr<DatasetOp> top_op = nullptr;
|
||||
if (!eoe_stack_.empty()) {
|
||||
top_op = eoe_stack_.top();
|
||||
eoe_stack_.pop();
|
||||
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
if (current_stack != nullptr && !current_stack->empty()) {
|
||||
top_op = current_stack->top();
|
||||
current_stack->pop();
|
||||
}
|
||||
return top_op;
|
||||
}
|
||||
|
|
|
@ -30,6 +30,8 @@ namespace dataset {
|
|||
/// to the eoe-producing (typically leaf) nodes underneath it.
|
||||
class RepeatPass : public NodePass {
|
||||
public:
|
||||
using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>;
|
||||
|
||||
/// \brief Constructor
|
||||
RepeatPass();
|
||||
|
||||
|
@ -39,6 +41,12 @@ class RepeatPass : public NodePass {
|
|||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being in a cache merge path
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
|
@ -51,6 +59,12 @@ class RepeatPass : public NodePass {
|
|||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
|
||||
|
||||
/// \brief CacheOp removes previous leaf ops and replaces them with itself
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
|
@ -86,11 +100,11 @@ class RepeatPass : public NodePass {
|
|||
/// \return shared_ptr to the popped operator
|
||||
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
|
||||
|
||||
bool is_repeated_; // T/F if we are processing under a repeat
|
||||
bool is_merge_; // T/F if we are processing under a cache merge op
|
||||
int32_t nested_repeats_; // A counter for nested repeats
|
||||
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops
|
||||
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
|
||||
bool is_repeated_; // T/F if we are processing under a repeat
|
||||
bool is_merge_; // T/F if we are processing under a cache merge op
|
||||
int32_t nested_repeats_; // A counter for nested repeats
|
||||
std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)
|
||||
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// constructor
|
||||
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}
|
||||
|
||||
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
|
||||
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
|
||||
if (injection_pass_) {
|
||||
injection_pass_->epoch_ctrl_bypass_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
|
||||
}
|
||||
}
|
||||
|
||||
// Temporary code to prevent the injection of epoch control when cache op is present
|
||||
// Remove this code in cache op phase 2
|
||||
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
if (injection_pass_) {
|
||||
injection_pass_->epoch_ctrl_bypass_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
|
||||
}
|
||||
}
|
||||
|
||||
// constructor
|
||||
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}
|
||||
|
||||
// Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass started.";
|
||||
|
||||
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
|
||||
// The finder can make updates to the InjectionPass object.
|
||||
InjectionPass::InjectionFinder finder(this);
|
||||
finder.Run(tree, modified);
|
||||
|
||||
// The first injection logic is to check if we should inject the epoch control op as the root node.
|
||||
// Do not inject the op if the number of epochs is 1.
|
||||
int32_t num_epochs = tree->num_epochs();
|
||||
if (num_epochs != 1 && !epoch_ctrl_bypass_) {
|
||||
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
|
||||
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
|
||||
std::shared_ptr<DatasetOp> node = tree->root();
|
||||
if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
|
||||
tree->root()->InsertAsParent(epoch_ctrl_op);
|
||||
} else {
|
||||
tree->root()->child(0)->InsertAsParent(epoch_ctrl_op);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class InjectionPass injection_pass.h
|
||||
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
|
||||
/// parsing.
|
||||
class InjectionPass : public TreePass {
|
||||
/// \class InjectionFinder
|
||||
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
|
||||
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
|
||||
/// it may need to inject.
|
||||
class InjectionFinder : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit InjectionFinder(InjectionPass *injection_pass);
|
||||
|
||||
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Temporary code to prevent the injection of epoch control when cache op is present.
|
||||
/// Remove this code in cache op phase 2
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
InjectionPass *injection_pass_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
InjectionPass();
|
||||
|
||||
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
private:
|
||||
bool epoch_ctrl_bypass_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
|
|
@ -29,20 +29,27 @@ std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
|
|||
return instance_ptr_;
|
||||
}
|
||||
|
||||
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) {
|
||||
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
|
||||
tdt::TdtDataType tdt_type) {
|
||||
MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
|
||||
std::vector<DataItem> items;
|
||||
double start_time;
|
||||
auto ret = translate(ts_row, items);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "TDT converting tensor failed!";
|
||||
return FAILED;
|
||||
if (tdt_type == tdt::TDT_TENSOR) {
|
||||
auto ret = translate(ts_row, items);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "TDT converting tensor failed!";
|
||||
return FAILED;
|
||||
}
|
||||
} else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) {
|
||||
DataItem data_item;
|
||||
data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE;
|
||||
items.emplace_back(data_item);
|
||||
MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE";
|
||||
}
|
||||
if (profiling) {
|
||||
start_time = ProfilingTime::GetCurMilliSecond();
|
||||
}
|
||||
if (tdt::TdtHostPushData(channel_name, items) != 0) {
|
||||
MS_LOG(ERROR) << "TDT pushing data failed!";
|
||||
return FAILED;
|
||||
}
|
||||
if (profiling) {
|
||||
|
@ -122,8 +129,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i
|
|||
data_item.dataPtr_ =
|
||||
std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
|
||||
items.emplace_back(data_item);
|
||||
MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is "
|
||||
<< ts->Size() << ".";
|
||||
MS_LOG(INFO) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes
|
||||
<< ", data length is " << ts->Size() << ".";
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -38,7 +38,8 @@ class TdtPlugin {
|
|||
public:
|
||||
static std::shared_ptr<TdtPlugin> GetInstance();
|
||||
|
||||
TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time);
|
||||
TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
|
||||
tdt::TdtDataType tdt_type = tdt::TDT_TENSOR);
|
||||
|
||||
private:
|
||||
TdtPlugin() {}
|
||||
|
|
|
@ -797,6 +797,9 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
|
|||
(void)InitBackend();
|
||||
}
|
||||
#endif
|
||||
if (iter_num == -1) {
|
||||
iter_num = INT32_MAX;
|
||||
}
|
||||
if (name == kMsConvert || name == kMsVm) {
|
||||
return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
||||
try:
|
||||
|
@ -946,14 +946,14 @@ class Dataset:
|
|||
raise TypeError("apply_func must return a dataset.")
|
||||
return dataset
|
||||
|
||||
@check_positive_int32
|
||||
def device_que(self, prefetch_size=None):
|
||||
def device_que(self, prefetch_size=None, send_epoch_end=True):
|
||||
"""
|
||||
Return a transferredDataset that transfer data through device.
|
||||
|
||||
Args:
|
||||
prefetch_size (int, optional): prefetch number of records ahead of the
|
||||
user's request (default=None).
|
||||
send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)
|
||||
|
||||
Note:
|
||||
If device is Ascend, features of data will be transferred one by one. The limitation
|
||||
|
@ -962,15 +962,14 @@ class Dataset:
|
|||
Return:
|
||||
TransferDataset, dataset for transferring.
|
||||
"""
|
||||
return self.to_device()
|
||||
return self.to_device(send_epoch_end=send_epoch_end)
|
||||
|
||||
@check_positive_int32
|
||||
def to_device(self, num_batch=None):
|
||||
def to_device(self, send_epoch_end=True):
|
||||
"""
|
||||
Transfer data through CPU, GPU or Ascend devices.
|
||||
|
||||
Args:
|
||||
num_batch (int, optional): limit the number of batch to be sent to device (default=None).
|
||||
send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)
|
||||
|
||||
Note:
|
||||
If device is Ascend, features of data will be transferred one by one. The limitation
|
||||
|
@ -982,19 +981,9 @@ class Dataset:
|
|||
Raises:
|
||||
TypeError: If device_type is empty.
|
||||
ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
|
||||
ValueError: If num_batch is not positive or larger than int_max.
|
||||
ValueError: If dataset size is None or 0.
|
||||
RuntimeError: If dataset is unknown.
|
||||
RuntimeError: If distribution file path is given but failed to read.
|
||||
"""
|
||||
if self.get_dataset_size() is None or 0:
|
||||
raise ValueError("dataset size is None or 0.")
|
||||
|
||||
if num_batch is None:
|
||||
num_batch = self.get_dataset_size()
|
||||
repeat_count = self.get_repeat_count()
|
||||
num_batch = num_batch * repeat_count
|
||||
|
||||
queue_name = str(uuid.uuid1())
|
||||
|
||||
if context:
|
||||
|
@ -1008,9 +997,6 @@ class Dataset:
|
|||
if device_type not in ('Ascend', 'GPU', 'CPU'):
|
||||
raise ValueError("Only support CPU, Ascend, GPU")
|
||||
|
||||
if num_batch == 0:
|
||||
raise ValueError("num_batch is 0.")
|
||||
|
||||
def get_distribution(output_dataset):
|
||||
dev_id = 0
|
||||
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
|
||||
|
@ -1032,7 +1018,7 @@ class Dataset:
|
|||
|
||||
distribution_path, device_id = get_distribution(self)
|
||||
if distribution_path == "":
|
||||
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
|
||||
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
|
||||
try:
|
||||
with open(distribution_path, 'r') as distribution_f:
|
||||
dist = json.load(distribution_f)
|
||||
|
@ -1042,7 +1028,7 @@ class Dataset:
|
|||
except Exception:
|
||||
raise RuntimeError("Distribution file failed to read")
|
||||
|
||||
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
|
||||
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
|
||||
|
||||
@check_save
|
||||
def save(self, file_name, num_files=1, file_type='mindrecord'):
|
||||
|
@ -1072,7 +1058,7 @@ class Dataset:
|
|||
|
||||
return SaveOp(self).save(file_names, file_type)
|
||||
|
||||
def create_tuple_iterator(self, columns=None):
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1):
|
||||
"""
|
||||
Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
|
||||
|
||||
|
@ -1098,9 +1084,9 @@ class Dataset:
|
|||
"""
|
||||
if self._noop_mode():
|
||||
return DummyIterator(self, 'tuple')
|
||||
return TupleIterator(self, columns)
|
||||
return TupleIterator(self, columns, num_epochs)
|
||||
|
||||
def create_dict_iterator(self):
|
||||
def create_dict_iterator(self, num_epochs=-1):
|
||||
"""
|
||||
Create an Iterator over the dataset.
|
||||
|
||||
|
@ -1123,7 +1109,7 @@ class Dataset:
|
|||
"""
|
||||
if self._noop_mode():
|
||||
return DummyIterator(self, 'dict')
|
||||
return DictIterator(self)
|
||||
return DictIterator(self, num_epochs)
|
||||
|
||||
def __iter__(self):
|
||||
"""Create an Iterator over the dataset."""
|
||||
|
@ -1149,7 +1135,7 @@ class Dataset:
|
|||
self._batch_size = device_iter.get_batch_size()
|
||||
self._num_classes = device_iter.num_classes()
|
||||
self._repeat_count = device_iter.get_repeat_count()
|
||||
device_iter.release()
|
||||
device_iter.stop()
|
||||
|
||||
def output_shapes(self):
|
||||
"""
|
||||
|
@ -2085,7 +2071,7 @@ class RepeatDataset(DatasetOp):
|
|||
"""
|
||||
child_size = self.children[0].get_dataset_size()
|
||||
if child_size is not None:
|
||||
return child_size
|
||||
return child_size * self.count
|
||||
return None
|
||||
|
||||
def get_repeat_count(self):
|
||||
|
@ -2097,7 +2083,6 @@ class RepeatDataset(DatasetOp):
|
|||
"""
|
||||
return self.count
|
||||
|
||||
|
||||
class SkipDataset(DatasetOp):
|
||||
"""
|
||||
The result of applying Skip operator to the input Dataset.
|
||||
|
@ -2317,10 +2302,10 @@ class TransferDataset(DatasetOp):
|
|||
queue_name (str): Name of device queue.
|
||||
device_id (int): Id of device.
|
||||
device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
|
||||
num_batch (int): limit the number of batch to be sent to device (default=None).
|
||||
send_epoch_end (bool, optional): Whether send end of sequence to device or not.(default=True)
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
|
||||
def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
|
||||
super().__init__()
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
|
@ -2328,7 +2313,7 @@ class TransferDataset(DatasetOp):
|
|||
self._input_indexs = input_dataset.input_indexs
|
||||
self._device_type = device_type
|
||||
self._device_id = device_id
|
||||
self.__num_batch = num_batch
|
||||
self._send_epoch_end = send_epoch_end
|
||||
self.iterator = None
|
||||
|
||||
def get_args(self):
|
||||
|
@ -2336,13 +2321,13 @@ class TransferDataset(DatasetOp):
|
|||
args["queue_name"] = self.queue_name
|
||||
args["device_type"] = self._device_type
|
||||
args["device_id"] = self._device_id
|
||||
args["num_batch"] = self.__num_batch
|
||||
args["send_epoch_end"] = self._send_epoch_end
|
||||
return args
|
||||
|
||||
def create_dict_iterator(self):
|
||||
def create_dict_iterator(self, num_epochs=-1):
|
||||
raise RuntimeError("TransferDataset is not iterable")
|
||||
|
||||
def create_tuple_iterator(self, columns=None):
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1):
|
||||
raise RuntimeError("TransferDataset is not iterable")
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -2354,12 +2339,14 @@ class TransferDataset(DatasetOp):
|
|||
def output_types(self):
|
||||
raise RuntimeError("TransferDataset does not support output_types")
|
||||
|
||||
def send(self):
|
||||
def send(self, num_epochs=-1):
|
||||
# need to keep iterator alive so the executionTree is not destroyed
|
||||
if self._noop_mode():
|
||||
return
|
||||
self.iterator = TupleIterator(self)
|
||||
self.iterator = TupleIterator(self, num_epochs=-1)
|
||||
|
||||
def stop_send(self):
|
||||
self.iterator.depipeline.StopSend()
|
||||
|
||||
class RangeDataset(MappableDataset):
|
||||
"""
|
||||
|
|
|
@ -29,7 +29,6 @@ from . import datasets as de
|
|||
|
||||
ITERATORS_LIST = list()
|
||||
|
||||
|
||||
def _cleanup():
|
||||
"""Release all the Iterator."""
|
||||
for itr_ref in ITERATORS_LIST:
|
||||
|
@ -60,7 +59,6 @@ def _alter_node(node):
|
|||
node.iterator_bootstrap()
|
||||
return node
|
||||
|
||||
|
||||
class Iterator:
|
||||
"""
|
||||
General Iterator over a dataset.
|
||||
|
@ -69,10 +67,21 @@ class Iterator:
|
|||
dataset: Dataset to be iterated over
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
def __init__(self, dataset, num_epochs=-1):
|
||||
self.num_epochs = num_epochs
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.parent_subtree = []
|
||||
|
||||
# The dataset passed into the iterator is not the root of the tree.
|
||||
# Trim the tree by saving the parent subtree into self.parent_subtree and
|
||||
# restore it after launching our c++ pipeline.
|
||||
if self.dataset.parent:
|
||||
logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
|
||||
self.parent_subtree = self.dataset.parent
|
||||
self.dataset.parent = []
|
||||
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
if not self.__is_tree():
|
||||
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
|
||||
|
@ -83,9 +92,17 @@ class Iterator:
|
|||
|
||||
root = self.__convert_node_postorder(self.dataset)
|
||||
self.depipeline.AssignRootNode(root)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
self.depipeline.LaunchTreeExec(self.num_epochs)
|
||||
self._index = 0
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Manually terminate python iterator instead of relying on out of scope destruction.
|
||||
"""
|
||||
logger.info("terminating python iterator. This will also terminate c++ pipeline.")
|
||||
if hasattr(self, 'depipeline') and self.depipeline:
|
||||
del self.depipeline
|
||||
|
||||
def __is_tree_node(self, node):
|
||||
"""Check if a node is tree node."""
|
||||
if not node.children:
|
||||
|
@ -214,9 +231,14 @@ class Iterator:
|
|||
|
||||
@abstractmethod
|
||||
def get_next(self):
|
||||
pass
|
||||
raise RuntimeError("Calling base class Iterator's get_next is invalid.")
|
||||
|
||||
def __next__(self):
|
||||
if not self.depipeline:
|
||||
logger.warning("Iterator does not have a running c++ pipeline." +
|
||||
"It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
|
||||
raise RuntimeError("Iterator does not have a running c++ pipeline.")
|
||||
|
||||
data = self.get_next()
|
||||
if not data:
|
||||
if self._index == 0:
|
||||
|
@ -293,12 +315,12 @@ class TupleIterator(Iterator):
|
|||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __init__(self, dataset, columns=None):
|
||||
def __init__(self, dataset, columns=None, num_epochs=-1):
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
dataset = dataset.project(columns)
|
||||
super().__init__(dataset)
|
||||
super().__init__(dataset, num_epochs)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
|
@ -57,7 +57,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
|||
|
||||
# transform data format
|
||||
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
||||
exec_dataset = exec_dataset.device_que()
|
||||
send_epoch_end = bool(dataset_size == -1)
|
||||
exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end)
|
||||
|
||||
_executor.init_dataset(exec_dataset.queue_name,
|
||||
dataset_size,
|
||||
|
@ -126,7 +127,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
|
|||
|
||||
|
||||
def _to_tensor(elem, scaling_sens=None):
|
||||
"""Conver numpy to tensor, adapt to minddata feed solution."""
|
||||
"""Convert numpy to tensor, adapt to feed the data from host solution."""
|
||||
lst = []
|
||||
if not isinstance(elem, (tuple, list)):
|
||||
elem = [elem]
|
||||
|
@ -145,7 +146,8 @@ def _to_tensor(elem, scaling_sens=None):
|
|||
|
||||
|
||||
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
|
||||
"""Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution."""
|
||||
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
|
||||
from host solution."""
|
||||
lst = []
|
||||
if not isinstance(elem, (tuple, list)):
|
||||
elem = [elem]
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import math
|
||||
import os
|
||||
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import check_bool, check_int
|
||||
from .. import context
|
||||
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
||||
_construct_tensor_list, _to_full_shapes, _to_full_tensor
|
||||
|
@ -42,17 +42,23 @@ class DatasetHelper:
|
|||
The iter of DatasetHelper will give one epoch data.
|
||||
|
||||
Args:
|
||||
dataset (DataSet): The dataset.
|
||||
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
|
||||
Default: True.
|
||||
dataset (DataSet): The training dataset iterator.
|
||||
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
|
||||
sink_size (int): Control the amount of data each sink.
|
||||
If sink_size=-1, sink the complete dataset each epoch.
|
||||
If sink_size>0, sink sink_size data each epoch. Default: -1.
|
||||
|
||||
Examples:
|
||||
>>> dataset_helper = DatasetHelper(dataset)
|
||||
>>> for inputs in dataset_helper:
|
||||
>>> outputs = network(*inputs)
|
||||
"""
|
||||
def __init__(self, dataset, dataset_sink_mode=True):
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
|
||||
check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("enable_ge"):
|
||||
|
@ -68,9 +74,10 @@ class DatasetHelper:
|
|||
iterclass = _DatasetIterMS
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
|
||||
self.iter = iterclass(dataset, sink_size)
|
||||
else:
|
||||
iterclass = _DatasetIterFeed
|
||||
self.iter = iterclass(dataset)
|
||||
iterclass = _DatasetIterNormal
|
||||
self.iter = iterclass(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
return self.iter.__iter__()
|
||||
|
@ -80,21 +87,26 @@ class DatasetHelper:
|
|||
"""Get the types and shapes from dataset on current config."""
|
||||
return self.iter.types_shapes()
|
||||
|
||||
def loop_size(self):
|
||||
"""Get loop_size for every iteration."""
|
||||
return self.iter.loop_size
|
||||
def sink_size(self):
|
||||
"""Get sink_size for every iteration."""
|
||||
return self.iter.get_sink_size()
|
||||
|
||||
def stop_send(self):
|
||||
"""Free up resources about data sink."""
|
||||
self.iter.stop_send()
|
||||
|
||||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset help"""
|
||||
def __init__(self, dataset):
|
||||
if not hasattr(dataset, '__loop_size__'):
|
||||
self.loop_size = dataset.get_dataset_size()
|
||||
else:
|
||||
self.loop_size = dataset.__loop_size__
|
||||
"""Base iter for dataset helper"""
|
||||
def __init__(self, dataset, sink_size):
|
||||
self.dataset = dataset
|
||||
self.sink_size = sink_size
|
||||
self.sink_count = 1
|
||||
|
||||
if not hasattr(dataset, '__ME_INITED__'):
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||
if not hasattr(dataset, '__TRANSFER_DATASET__'):
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
self.sink_size = dataset.__loop_size__
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
|
@ -102,43 +114,70 @@ class _DatasetIter:
|
|||
else:
|
||||
_send_data(dataset)
|
||||
|
||||
self.ind = 0
|
||||
self.dataset = dataset
|
||||
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
|
||||
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
|
||||
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
self.ind = 0
|
||||
self.index = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.ind >= self.loop_count:
|
||||
if self.index >= self.sink_count:
|
||||
raise StopIteration()
|
||||
self.ind += 1
|
||||
self.index += 1
|
||||
return self.op()
|
||||
|
||||
def types_shapes(self):
|
||||
return self.dataset_types, self.dataset_shapes
|
||||
|
||||
def get_loop_count(self, dataset):
|
||||
loop_count = 1
|
||||
def get_sink_count(self, dataset):
|
||||
sink_count = 1
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
loop_size = dataset.__loop_size__
|
||||
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
|
||||
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
|
||||
f'loop_size {loop_size} are not matched.')
|
||||
loop_count = math.ceil(dataset.get_dataset_size() / loop_size)
|
||||
return loop_count
|
||||
f'sink_size {loop_size} are not matched.')
|
||||
sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
|
||||
return sink_count
|
||||
|
||||
def get_sink_size(self):
|
||||
"""get sink_size to device"""
|
||||
sink_size = 1
|
||||
if hasattr(self.dataset, '__loop_size__'):
|
||||
sink_size = self.dataset.__loop_size__
|
||||
else:
|
||||
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
|
||||
if self.sink_size > 0:
|
||||
sink_size = self.sink_size
|
||||
else:
|
||||
sink_size = self.dataset.get_dataset_size()
|
||||
return sink_size
|
||||
|
||||
|
||||
class _DatasetIterGE(_DatasetIter):
|
||||
"""Iter for GE."""
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
self.sink_count = self.get_sink_count(dataset)
|
||||
batch_expand_num = 1
|
||||
if _need_to_full():
|
||||
batch_expand_num = _get_device_num()
|
||||
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
|
||||
|
||||
def op():
|
||||
return tensor_list_run
|
||||
|
||||
self.op = op
|
||||
|
||||
|
||||
class _DatasetIterMSLoopSink(_DatasetIter):
|
||||
"""Iter for context (device_target=Ascend)"""
|
||||
def __init__(self, dataset):
|
||||
super(_DatasetIterMSLoopSink, self).__init__(dataset)
|
||||
self.loop_count = self.get_loop_count(dataset)
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
self.sink_count = self.get_sink_count(dataset)
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
self.loop_count = 1
|
||||
self.sink_count = 1
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
|
||||
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
|
||||
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
|
@ -153,66 +192,42 @@ class _DatasetIterMSLoopSink(_DatasetIter):
|
|||
|
||||
|
||||
class _DatasetIterMS(_DatasetIter):
|
||||
"""Iter for context (device_target=GPU)"""
|
||||
def __init__(self, dataset):
|
||||
super(_DatasetIterMS, self).__init__(dataset)
|
||||
self.loop_count = dataset.get_dataset_size()
|
||||
self.loop_size = 1
|
||||
"""Iter for MS(enable_loop_sink=False)."""
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
if sink_size > 0:
|
||||
self.sink_count = sink_size
|
||||
else:
|
||||
self.sink_count = dataset.get_dataset_size()
|
||||
|
||||
queue_name = dataset.__ME_INITED__
|
||||
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
||||
|
||||
|
||||
class _DatasetIterPSLite(_DatasetIter):
|
||||
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
|
||||
def __init__(self, dataset):
|
||||
super(_DatasetIterPSLite, self).__init__(dataset)
|
||||
self.loop_count = 1
|
||||
self.loop_size = 1
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
self.sink_count = 1
|
||||
self.sink_size = 1
|
||||
self.op = None
|
||||
def op():
|
||||
return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
|
||||
self.op = op
|
||||
|
||||
|
||||
class _DatasetIterGE(_DatasetIter):
|
||||
"""Iter for ge"""
|
||||
def __init__(self, dataset):
|
||||
super(_DatasetIterGE, self).__init__(dataset)
|
||||
self.loop_count = self.get_loop_count(dataset)
|
||||
batch_expand_num = 1
|
||||
if _need_to_full():
|
||||
batch_expand_num = _get_device_num()
|
||||
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
|
||||
|
||||
def op():
|
||||
return tensor_list_run
|
||||
|
||||
self.op = op
|
||||
|
||||
|
||||
class _DatasetIterFeed:
|
||||
class _DatasetIterNormal:
|
||||
"""Iter for normal(non sink) mode, feed the data from host."""
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
self.device_num = _get_device_num()
|
||||
self.global_rank = _get_global_rank()
|
||||
self.repeat_count = dataset.get_repeat_count()
|
||||
self.repeat_ind = 0
|
||||
self.loop_count = dataset.get_dataset_size()
|
||||
self.ind = 0
|
||||
|
||||
def __iter__(self):
|
||||
if self.repeat_ind % self.repeat_count == 0:
|
||||
self.iter = self.dataset.__iter__()
|
||||
|
||||
self.repeat_ind += 1
|
||||
self.ind = 0
|
||||
self.iter = self.dataset.create_tuple_iterator()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.ind >= self.loop_count:
|
||||
raise StopIteration()
|
||||
self.ind += 1
|
||||
data = self.iter.__next__()
|
||||
if _need_to_full():
|
||||
return _to_full_tensor(data, self.device_num, self.global_rank)
|
||||
|
|
|
@ -21,7 +21,7 @@ import numpy as np
|
|||
from mindspore import log as logger
|
||||
from ..common.tensor import Tensor
|
||||
from ..nn.metrics import get_metrics
|
||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
|
||||
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
|
@ -225,7 +225,7 @@ class Model:
|
|||
scaling_sens /= self._device_number
|
||||
return scaling_sens
|
||||
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode):
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
|
||||
"""Initializes dataset."""
|
||||
need_wrap = False
|
||||
if dataset_sink_mode:
|
||||
|
@ -237,7 +237,7 @@ class Model:
|
|||
if not is_train:
|
||||
dataset.__loop_size__ = 1
|
||||
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)
|
||||
|
||||
# remove later to deal with loop sink
|
||||
if need_wrap:
|
||||
|
@ -317,7 +317,7 @@ class Model:
|
|||
self._eval_network.compile(*inputs)
|
||||
break
|
||||
|
||||
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Training.
|
||||
|
||||
|
@ -332,6 +332,7 @@ class Model:
|
|||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
epoch = check_int_positive(epoch)
|
||||
self._train_network.set_train()
|
||||
|
@ -342,7 +343,10 @@ class Model:
|
|||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
cb_params.batch_num = train_dataset.get_dataset_size()
|
||||
if dataset_sink_mode and sink_size > 0:
|
||||
cb_params.batch_num = sink_size
|
||||
else:
|
||||
cb_params.batch_num = train_dataset.get_dataset_size()
|
||||
cb_params.mode = "train"
|
||||
cb_params.loss_fn = self._loss_fn
|
||||
cb_params.optimizer = self._optimizer
|
||||
|
@ -364,7 +368,7 @@ class Model:
|
|||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
|
||||
|
||||
@staticmethod
|
||||
def _transform_callbacks(callbacks):
|
||||
|
@ -377,7 +381,7 @@ class Model:
|
|||
|
||||
return [callbacks]
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
|
||||
"""
|
||||
Training process. The data would be passed to network through dataset channel.
|
||||
|
||||
|
@ -390,17 +394,18 @@ class Model:
|
|||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True)
|
||||
dataset_sink_mode=True,
|
||||
sink_size=sink_size)
|
||||
self._train_network = train_network
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.cur_step_num = 0
|
||||
|
||||
loop_size = dataset_helper.loop_size()
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
|
@ -412,9 +417,9 @@ class Model:
|
|||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
cb_params.cur_step_num += loop_size
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.cur_step_num += dataset_helper.sink_size()
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
|
||||
|
@ -422,6 +427,7 @@ class Model:
|
|||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
dataset_helper.stop_send()
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
|
@ -490,7 +496,7 @@ class Model:
|
|||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Training API where the iteration is controlled by python front-end.
|
||||
|
||||
|
@ -515,7 +521,10 @@ class Model:
|
|||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
|
||||
sink_size (int): Control the amount of data each sink.
|
||||
If sink_size=-1, sink the complete dataset each epoch.
|
||||
If sink_size>0, sink sink_size data each epoch.
|
||||
If dataset_sink_mode is False, set sink_size invalid. Default: -1.
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
|
@ -526,17 +535,19 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
repeat_count = train_dataset.get_repeat_count()
|
||||
if epoch != repeat_count and dataset_sink_mode is True:
|
||||
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
|
||||
check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
|
||||
self._train(epoch,
|
||||
train_dataset,
|
||||
callbacks=callbacks,
|
||||
dataset_sink_mode=dataset_sink_mode)
|
||||
dataset_sink_mode=dataset_sink_mode,
|
||||
sink_size=sink_size)
|
||||
|
||||
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
|
|
|
@ -43,7 +43,7 @@ if __name__ == "__main__":
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
||||
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size)
|
||||
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
|
||||
network = AlexNet(cfg.num_classes)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
|
||||
|
|
|
@ -57,7 +57,7 @@ if __name__ == '__main__':
|
|||
|
||||
ds_train = create_dataset(args_opt.dataset_path,
|
||||
train_mode=True,
|
||||
epochs=train_config.train_epochs,
|
||||
epochs=1,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format),
|
||||
rank_size=rank_size,
|
||||
|
@ -82,7 +82,7 @@ if __name__ == '__main__':
|
|||
|
||||
if args_opt.do_eval:
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=train_config.train_epochs,
|
||||
epochs=1,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format))
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
|
||||
|
|
|
@ -66,7 +66,7 @@ if __name__ == "__main__":
|
|||
init()
|
||||
args_opt.base_size = config.crop_size
|
||||
args_opt.crop_size = config.crop_size
|
||||
train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train")
|
||||
train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train")
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
callback = [time_cb, LossCallBack()]
|
||||
|
|
|
@ -94,7 +94,7 @@ if __name__ == '__main__':
|
|||
loss_scale = float(config.loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
|
||||
dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=config.epoch_size,
|
||||
dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1,
|
||||
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
|
|
|
@ -78,7 +78,7 @@ if __name__ == '__main__':
|
|||
mirror_mean=True)
|
||||
init()
|
||||
|
||||
dataset = create_dataset(cfg.data_path, cfg.epoch_size)
|
||||
dataset = create_dataset(cfg.data_path, 1)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
net = GoogleNet(num_classes=cfg.num_classes)
|
||||
|
|
|
@ -45,8 +45,7 @@ if __name__ == "__main__":
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"),
|
||||
cfg.batch_size,
|
||||
cfg.epoch_size)
|
||||
cfg.batch_size)
|
||||
|
||||
network = LeNet5(cfg.num_classes)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
|
|
|
@ -44,7 +44,7 @@ args = parser.parse_args()
|
|||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
# define fusion network
|
||||
|
|
|
@ -77,7 +77,7 @@ if __name__ == '__main__':
|
|||
model = Model(network, loss, opt, {'acc': Accuracy()})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
|
||||
ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
|
||||
|
|
|
@ -249,7 +249,7 @@ def train_parallel(config: TransformerConfig):
|
|||
|
||||
pre_train_dataset = load_dataset(
|
||||
data_files=config.pre_train_dataset,
|
||||
batch_size=config.batch_size, epoch_count=config.epochs,
|
||||
batch_size=config.batch_size, epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
|
@ -257,7 +257,7 @@ def train_parallel(config: TransformerConfig):
|
|||
) if config.pre_train_dataset else None
|
||||
fine_tune_dataset = load_dataset(
|
||||
data_files=config.fine_tune_dataset,
|
||||
batch_size=config.batch_size, epoch_count=config.epochs,
|
||||
batch_size=config.batch_size, epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
|
@ -265,7 +265,7 @@ def train_parallel(config: TransformerConfig):
|
|||
) if config.fine_tune_dataset else None
|
||||
test_dataset = load_dataset(
|
||||
data_files=config.test_dataset,
|
||||
batch_size=config.batch_size, epoch_count=config.epochs,
|
||||
batch_size=config.batch_size, epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
|
@ -288,17 +288,17 @@ def train_single(config: TransformerConfig):
|
|||
print(" | Starting training on single device.")
|
||||
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
|
||||
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
|
||||
test_dataset = load_dataset(data_files=config.test_dataset,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epochs,
|
||||
epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
sink_step=config.dataset_sink_step) if config.test_dataset else None
|
||||
|
||||
|
|
|
@ -180,7 +180,7 @@ if __name__ == '__main__':
|
|||
do_train=True,
|
||||
config=config_gpu,
|
||||
platform=args_opt.platform,
|
||||
repeat_num=epoch_size,
|
||||
repeat_num=1,
|
||||
batch_size=config_gpu.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# resume
|
||||
|
@ -239,7 +239,7 @@ if __name__ == '__main__':
|
|||
do_train=True,
|
||||
config=config_ascend,
|
||||
platform=args_opt.platform,
|
||||
repeat_num=epoch_size,
|
||||
repeat_num=1,
|
||||
batch_size=config_ascend.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
if args_opt.pre_trained:
|
||||
|
|
|
@ -86,7 +86,7 @@ if __name__ == '__main__':
|
|||
do_train=True,
|
||||
config=config,
|
||||
device_target=args_opt.device_target,
|
||||
repeat_num=epoch_size,
|
||||
repeat_num=1,
|
||||
batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# load pre trained ckpt
|
||||
|
|
|
@ -181,7 +181,7 @@ if __name__ == '__main__':
|
|||
do_train=True,
|
||||
config=config_gpu,
|
||||
platform=args_opt.platform,
|
||||
repeat_num=epoch_size,
|
||||
repeat_num=1,
|
||||
batch_size=config_gpu.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# resume
|
||||
|
@ -240,7 +240,7 @@ if __name__ == '__main__':
|
|||
do_train=True,
|
||||
config=config_ascend,
|
||||
platform=args_opt.platform,
|
||||
repeat_num=epoch_size,
|
||||
repeat_num=1,
|
||||
batch_size=config_ascend.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
if args_opt.pre_trained:
|
||||
|
|
|
@ -36,12 +36,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
|
||||
_cur_dir = os.getcwd()
|
||||
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
|
||||
""" do train """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
|
||||
steps_per_epoch = dataset.get_dataset_size()
|
||||
epoch_num = dataset.get_repeat_count()
|
||||
# optimizer
|
||||
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
|
||||
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
|
||||
|
@ -176,11 +175,11 @@ def run_classifier():
|
|||
assessment_method=assessment_method)
|
||||
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
|
||||
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method,
|
||||
data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if save_finetune_checkpoint_path == "":
|
||||
|
@ -191,7 +190,7 @@ def run_classifier():
|
|||
ds.get_dataset_size(), epoch_num, "classifier")
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
|
||||
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method,
|
||||
data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path)
|
||||
|
|
|
@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
_cur_dir = os.getcwd()
|
||||
|
||||
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
|
||||
""" do train """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
|
||||
steps_per_epoch = dataset.get_dataset_size()
|
||||
epoch_num = dataset.get_repeat_count()
|
||||
# optimizer
|
||||
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
|
||||
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
|
||||
|
@ -204,10 +203,10 @@ def run_ner():
|
|||
use_crf=(args_opt.use_crf.lower() == "true"),
|
||||
tag_to_index=tag_to_index, dropout_prob=0.1)
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
|
||||
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if save_finetune_checkpoint_path == "":
|
||||
|
@ -218,7 +217,7 @@ def run_ner():
|
|||
ds.get_dataset_size(), epoch_num, "ner")
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
|
||||
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path)
|
||||
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
|
||||
|
|
|
@ -100,11 +100,12 @@ def run_pretrain():
|
|||
bert_net_cfg.compute_type = mstype.float32
|
||||
|
||||
|
||||
ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.enable_data_sink, args_opt.data_sink_steps,
|
||||
args_opt.data_dir, args_opt.schema_dir)
|
||||
ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
|
||||
args_opt.enable_data_sink, args_opt.data_sink_steps,
|
||||
args_opt.data_dir, args_opt.schema_dir)
|
||||
new_repeat_count = args_opt.epoch_size
|
||||
if args_opt.train_steps > 0:
|
||||
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
|
||||
new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
|
||||
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
|
||||
|
||||
if cfg.optimizer == 'Lamb':
|
||||
|
|
|
@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
|
||||
_cur_dir = os.getcwd()
|
||||
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
|
||||
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
|
||||
""" do train """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
|
||||
steps_per_epoch = dataset.get_dataset_size()
|
||||
epoch_num = dataset.get_repeat_count()
|
||||
# optimizer
|
||||
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
|
||||
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
|
||||
|
@ -181,10 +180,10 @@ def run_squad():
|
|||
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
|
||||
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
|
||||
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
||||
data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if save_finetune_checkpoint_path == "":
|
||||
load_finetune_checkpoint_dir = _cur_dir
|
||||
|
@ -194,7 +193,7 @@ def run_squad():
|
|||
ds.get_dataset_size(), epoch_num, "squad")
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
|
||||
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
|
||||
data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path, is_training=False)
|
||||
do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
|
||||
|
|
|
@ -54,7 +54,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
|||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(max(new_repeat_count, repeat_count))
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds, new_repeat_count
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from mindspore import log as logger
|
||||
from .config import transformer_net_cfg
|
||||
|
||||
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
|
||||
|
@ -42,7 +41,4 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
|
|||
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_count)
|
||||
|
||||
ds.channel_name = 'transformer'
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds, repeat_count
|
||||
return ds
|
||||
|
|
|
@ -125,10 +125,10 @@ def run_transformer_train():
|
|||
else:
|
||||
device_num = 1
|
||||
rank_id = 0
|
||||
dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num,
|
||||
rank_id=rank_id, do_shuffle=args.do_shuffle,
|
||||
enable_data_sink=args.enable_data_sink,
|
||||
dataset_path=args.data_path)
|
||||
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
|
||||
rank_id=rank_id, do_shuffle=args.do_shuffle,
|
||||
enable_data_sink=args.enable_data_sink,
|
||||
dataset_path=args.data_path)
|
||||
|
||||
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
|
||||
|
||||
|
@ -165,7 +165,7 @@ def run_transformer_train():
|
|||
|
||||
netwithgrads.set_train(True)
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
|
||||
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_transformer_train()
|
||||
|
|
|
@ -88,10 +88,10 @@ if __name__ == '__main__':
|
|||
|
||||
# create dataset
|
||||
if args_opt.net == "resnet50":
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size,
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
|
||||
batch_size=config.batch_size, target=target)
|
||||
else:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size,
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
|
||||
batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
|
|
|
@ -105,7 +105,7 @@ if __name__ == '__main__':
|
|||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
if args_opt.do_train:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
||||
repeat_num=epoch_size, batch_size=config.batch_size)
|
||||
batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
|
|
@ -91,7 +91,7 @@ def main():
|
|||
loss_scale = float(args_opt.loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0.
|
||||
dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
|
||||
dataset = create_ssd_dataset(mindrecord_file, repeat_num=1,
|
||||
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
|
|
|
@ -83,7 +83,7 @@ if __name__ == '__main__':
|
|||
mirror_mean=True)
|
||||
init()
|
||||
|
||||
dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size)
|
||||
dataset = vgg_create_dataset(args_opt.data_path, 1)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
net = vgg16(num_classes=cfg.num_classes)
|
||||
|
|
|
@ -63,7 +63,7 @@ def test_train(configure):
|
|||
data_path = configure.data_path
|
||||
batch_size = configure.batch_size
|
||||
epochs = configure.epochs
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
|
||||
net_builder = ModelBuilder()
|
||||
|
|
|
@ -67,8 +67,8 @@ def test_train_eval(config):
|
|||
data_path = config.data_path
|
||||
batch_size = config.batch_size
|
||||
epochs = config.epochs
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size)
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
||||
|
|
|
@ -85,14 +85,14 @@ def train_and_eval(config):
|
|||
if config.full_batch:
|
||||
context.set_auto_parallel_context(full_batch=True)
|
||||
de.config.set_seed(1)
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=batch_size*get_group_size())
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=batch_size*get_group_size())
|
||||
else:
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
|
|
@ -74,9 +74,9 @@ def train_and_eval(config):
|
|||
batch_size = config.batch_size
|
||||
epochs = config.epochs
|
||||
print("epochs is {}".format(epochs))
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
|
|
@ -121,7 +121,7 @@ def main():
|
|||
loss_scale = float(args_opt.loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
|
||||
dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
|
||||
dataset = create_yolo_dataset(mindrecord_file,
|
||||
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("Create dataset done!")
|
||||
|
|
|
@ -50,13 +50,20 @@ class MindData:
|
|||
def input_indexs(self):
|
||||
return self._input_indexs
|
||||
|
||||
def device_que(self):
|
||||
def device_que(self, send_epoch_end=True):
|
||||
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
|
||||
self.send_epoch_end = send_epoch_end
|
||||
return self
|
||||
|
||||
def create_tuple_iterator(self):
|
||||
return self.__iter__()
|
||||
|
||||
def send(self):
|
||||
pass
|
||||
|
||||
def stop_send(self):
|
||||
pass
|
||||
|
||||
def __len__(self):
|
||||
return self._size
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ if __name__ == "__main__":
|
|||
epoch_size = 3
|
||||
args_opt.base_size = config.crop_size
|
||||
args_opt.crop_size = config.crop_size
|
||||
train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size,
|
||||
train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size,
|
||||
usage="train", shuffle=False)
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
callback = LossCallBack(dataset_size)
|
||||
|
|
|
@ -120,10 +120,10 @@ def test_transformer():
|
|||
batch_size = 96
|
||||
epoch_size = 3
|
||||
config = get_config(version=version, batch_size=batch_size)
|
||||
dataset, repeat_count = create_transformer_dataset(epoch_count=epoch_size,
|
||||
do_shuffle="false",
|
||||
enable_data_sink="false",
|
||||
dataset_path=DATA_DIR)
|
||||
dataset = create_transformer_dataset(epoch_count=1,
|
||||
do_shuffle="false",
|
||||
enable_data_sink="false",
|
||||
dataset_path=DATA_DIR)
|
||||
|
||||
netwithloss = TransformerNetworkWithLoss(config, True)
|
||||
|
||||
|
@ -146,7 +146,7 @@ def test_transformer():
|
|||
netwithgrads.set_train(True)
|
||||
time_monitor_callback = TimeMonitor(dataset.get_dataset_size())
|
||||
model = Model(netwithgrads)
|
||||
model.train(repeat_count, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)
|
||||
model.train(epoch_size, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)
|
||||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
|
|
|
@ -79,9 +79,9 @@ def test_train_eval():
|
|||
batch_size = config.batch_size
|
||||
epochs = config.epochs
|
||||
print("epochs is {}".format(epochs))
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size,
|
||||
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size,
|
||||
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
|
|
@ -76,9 +76,9 @@ def test_train_eval():
|
|||
batch_size = config.batch_size
|
||||
epochs = config.epochs
|
||||
print("epochs is {}".format(epochs))
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
|
|
@ -113,7 +113,7 @@ def test_yolov3():
|
|||
loss_scale = float(loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
|
||||
dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size,
|
||||
dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
|
||||
batch_size=batch_size, device_num=device_num, rank=rank)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("Create dataset done!")
|
||||
|
@ -146,12 +146,12 @@ def test_yolov3():
|
|||
assert loss_value[2] < expect_loss_value[2]
|
||||
|
||||
epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
|
||||
expect_epoch_mseconds = 950
|
||||
expect_epoch_mseconds = 2000
|
||||
print("epoch mseconds: {}".format(epoch_mseconds))
|
||||
assert epoch_mseconds <= expect_epoch_mseconds
|
||||
|
||||
per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
|
||||
expect_per_step_mseconds = 110
|
||||
expect_per_step_mseconds = 220
|
||||
print("per step mseconds: {}".format(per_step_mseconds))
|
||||
assert per_step_mseconds <= expect_per_step_mseconds
|
||||
print("yolov3 test case passed.")
|
||||
|
|
|
@ -91,6 +91,7 @@ def me_de_train_dataset(sink_mode=False):
|
|||
"""test me de train dataset"""
|
||||
# apply repeat operations
|
||||
repeat_count = 1
|
||||
sink_size = -1
|
||||
batch_size = 16
|
||||
ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
|
||||
"next_sentence_labels", "masked_lm_positions",
|
||||
|
@ -99,9 +100,9 @@ def me_de_train_dataset(sink_mode=False):
|
|||
new_repeat_count = repeat_count
|
||||
if sink_mode:
|
||||
repeat_count = 30
|
||||
sink_steps = 100
|
||||
sink_size = 100
|
||||
ori_dataaet_size = ds.get_dataset_size()
|
||||
new_size = sink_steps * batch_size
|
||||
new_size = sink_size * batch_size
|
||||
ds.set_dataset_size(new_size)
|
||||
new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size())
|
||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||
|
@ -112,10 +113,9 @@ def me_de_train_dataset(sink_mode=False):
|
|||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_count)
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeat_count: {}".format(ds.get_repeat_count()))
|
||||
return ds, new_repeat_count
|
||||
return ds, new_repeat_count, sink_size
|
||||
|
||||
|
||||
def weight_variable(shape):
|
||||
|
@ -157,7 +157,7 @@ class TimeMonitor(Callback):
|
|||
def test_bert_percision():
|
||||
"""test bert percision"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
|
||||
ds, new_repeat_count = me_de_train_dataset()
|
||||
ds, new_repeat_count, _ = me_de_train_dataset()
|
||||
version = os.getenv('VERSION', 'large')
|
||||
batch_size = 16
|
||||
config = get_config(version=version, batch_size=batch_size)
|
||||
|
@ -215,7 +215,7 @@ def test_bert_percision():
|
|||
def test_bert_performance():
|
||||
"""test bert performance"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
|
||||
ds, new_repeat_count = me_de_train_dataset(sink_mode=True)
|
||||
ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True)
|
||||
version = os.getenv('VERSION', 'large')
|
||||
batch_size = 16
|
||||
config = get_config(version=version, batch_size=batch_size)
|
||||
|
@ -251,7 +251,7 @@ def test_bert_performance():
|
|||
param.default_input = weight_variable(value.asnumpy().shape)
|
||||
time_monitor_callback = TimeMonitor(ds.get_dataset_size())
|
||||
model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
|
||||
dataset_sink_mode=True)
|
||||
dataset_sink_mode=True, sink_size=sink_size)
|
||||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
|
|
|
@ -79,7 +79,7 @@ def test_deeplabv3_1p():
|
|||
args_opt.base_size = config.crop_size
|
||||
args_opt.crop_size = config.crop_size
|
||||
args_opt.batch_size = config.batch_size
|
||||
train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size,
|
||||
train_dataset = create_dataset(args_opt, data_url, 1, config.batch_size,
|
||||
usage="eval")
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
callback = LossCallBack(dataset_size)
|
||||
|
|
|
@ -155,7 +155,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
|
|||
|
||||
# train dataset
|
||||
dataset = create_dataset(dataset_path=dataset_path, do_train=True,
|
||||
repeat_num=epoch_size, batch_size=config.batch_size)
|
||||
repeat_num=1, batch_size=config.batch_size)
|
||||
|
||||
step_size = dataset.get_dataset_size()
|
||||
eval_interval = config.eval_interval
|
||||
|
@ -163,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
|
|||
|
||||
# evalutation dataset
|
||||
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
|
||||
repeat_num=epoch_size, batch_size=config.eval_batch_size)
|
||||
repeat_num=1, batch_size=config.eval_batch_size)
|
||||
|
||||
# loss scale
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
@ -260,14 +260,14 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
|
|||
|
||||
# train dataset
|
||||
dataset = create_dataset(dataset_path=dataset_path, do_train=True,
|
||||
repeat_num=epoch_size, batch_size=thor_config.batch_size)
|
||||
repeat_num=1, batch_size=thor_config.batch_size)
|
||||
|
||||
step_size = dataset.get_dataset_size()
|
||||
eval_interval = thor_config.eval_interval
|
||||
|
||||
# evalutation dataset
|
||||
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
|
||||
repeat_num=epoch_size, batch_size=thor_config.eval_batch_size)
|
||||
repeat_num=1, batch_size=thor_config.eval_batch_size)
|
||||
|
||||
# loss scale
|
||||
loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)
|
||||
|
|
|
@ -136,7 +136,7 @@ if __name__ == '__main__':
|
|||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
if args_opt.do_train:
|
||||
dataset = create_dataset(epoch_size)
|
||||
dataset = create_dataset(1)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
|
||||
|
|
|
@ -140,7 +140,7 @@ def train_process(epoch_size, num_classes, batch_size):
|
|||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
|
||||
dataset = create_dataset(1, training=True, batch_size=batch_size)
|
||||
loss_cb = LossGet()
|
||||
model.train(epoch_size, dataset, callbacks=[loss_cb])
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
|
|||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
dataset = create_dataset(epoch_size, training=True,
|
||||
dataset = create_dataset(1, training=True,
|
||||
batch_size=batch_size, rank_id=device_id, rank_size=device_num,
|
||||
enable_hccl=enable_hccl)
|
||||
|
||||
|
|
|
@ -91,8 +91,9 @@ SET(DE_UT_SRCS
|
|||
cyclic_array_test.cc
|
||||
perf_data_test.cc
|
||||
c_api_test.cc
|
||||
tensor_op_fusion_pass_test.cc
|
||||
tensor_op_fusion_pass_test.cc
|
||||
sliding_window_op_test.cc
|
||||
epoch_ctrl_op_test.cc
|
||||
)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
|
|
@ -397,23 +397,21 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
|
|||
|
||||
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
|
||||
|
||||
std::shared_ptr<CacheMergeOp> myMergeOp;
|
||||
rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build(
|
||||
&myMergeOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
// In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
|
||||
// Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
|
||||
// adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
|
||||
// replace it with the required tree structures for cache lookup op and cache merge op.
|
||||
|
||||
std::shared_ptr<CacheLookupOp> myLookupOp;
|
||||
rc = CacheLookupOp::Builder()
|
||||
.SetNumWorkers(3)
|
||||
.SetOpConnectorSize(3)
|
||||
std::shared_ptr<CacheOp> myCacheOp;
|
||||
rc = CacheOp::Builder()
|
||||
.SetNumWorkers(4)
|
||||
.SetClient(myClient)
|
||||
.SetSampler(seq_sampler)
|
||||
.Build(&myLookupOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
.SetRowsPerBuffer(3)
|
||||
.Build(&myCacheOp);
|
||||
|
||||
std::shared_ptr<ImageFolderOp> so;
|
||||
ImageFolderOp::Builder builder;
|
||||
builder.SetSampler(myLookupOp)
|
||||
builder.SetSampler(std::move(seq_sampler))
|
||||
.SetOpConnectorSize(3)
|
||||
.SetNumWorkers(3)
|
||||
.SetRowsPerBuffer(2)
|
||||
|
@ -432,20 +430,18 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
|
|||
auto myTree = std::make_shared<ExecutionTree>();
|
||||
rc = myTree->AssociateNode(so);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = myTree->AssociateNode(myLookupOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = myTree->AssociateNode(myMergeOp);
|
||||
|
||||
rc = myTree->AssociateNode(myCacheOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = myTree->AssociateNode(myRepeatOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = myTree->AssignRoot(myRepeatOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = myRepeatOp->AddChild(myMergeOp);
|
||||
rc = myRepeatOp->AddChild(myCacheOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = myMergeOp->AddChild(myLookupOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = myMergeOp->AddChild(so);
|
||||
rc = myCacheOp->AddChild(so);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = myTree->Prepare();
|
||||
|
|
|
@ -0,0 +1,639 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
|
||||
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, bool decode = false);
|
||||
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
class MindDataTestEpochCtrlOp : public UT::DatasetOpTesting {
|
||||
public:
|
||||
void SetUp() override {
|
||||
DatasetOpTesting::SetUp();
|
||||
folder_path = datasets_root_path_ + "/testPK/data";
|
||||
|
||||
GlobalInit();
|
||||
|
||||
// Start with an empty execution tree
|
||||
my_tree_ = std::make_shared<ExecutionTree>();
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
|
||||
rc = my_tree_->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
int32_t i = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
golden_imgs.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ExecutionTree> my_tree_;
|
||||
Status rc;
|
||||
std::string golden_imgs;
|
||||
std::string folder_path;
|
||||
int32_t label = 0;
|
||||
std::string result;
|
||||
int32_t img_class[4] = {0, 1, 2, 3};
|
||||
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_AutoInjectEpoch) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_AutoInjectEpoch.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
|
||||
rc = my_tree_->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch;
|
||||
std::string golden = golden_imgs;
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
uint64_t i = 0;
|
||||
for (int epoch = 0; epoch < num_epoch; epoch++) {
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
|
||||
EXPECT_TRUE(i == 44 * num_epoch);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Epoch.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch;
|
||||
std::string golden = golden_imgs;
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
uint64_t i = 0;
|
||||
for (int epoch = 0; epoch < num_epoch; epoch++) {
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
|
||||
EXPECT_TRUE(i == 44 * num_epoch);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_FALSE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
|
||||
int32_t num_repeats = 2;
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats;
|
||||
std::string golden = golden_imgs;
|
||||
for (int i = 1; i < num_repeats; i++) {
|
||||
golden += golden_imgs;
|
||||
}
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
uint64_t i = 0;
|
||||
for (int epoch = 0; epoch < num_epoch; epoch++) {
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
|
||||
EXPECT_TRUE(i == 44 * num_repeats * num_epoch);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_FALSE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
|
||||
int32_t num_repeats = 2;
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
int32_t num_repeats_2 = 3;
|
||||
std::shared_ptr<RepeatOp> repeat_op_2;
|
||||
rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2;
|
||||
std::string golden;
|
||||
for (int j = 0; j < num_repeats_2; j++) {
|
||||
for (int i = 0; i < num_repeats; i++) {
|
||||
golden += golden_imgs;
|
||||
}
|
||||
}
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
uint64_t i = 0;
|
||||
for (int epoch = 0; epoch < num_epoch; epoch++) {
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(result.size(), golden.size());
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 44 * num_epoch * num_repeats * num_repeats_2);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_FALSE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf.";
|
||||
|
||||
// if num_epoch == -1, it means infinity.
|
||||
int32_t num_epoch = -1;
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
uint64_t i = 0;
|
||||
|
||||
// For this test, we stop at stop_at_epoch number.
|
||||
int32_t stop_at_epoch = 2 + std::rand() % 6;
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
|
||||
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(result, golden_imgs);
|
||||
result.clear();
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
EXPECT_TRUE(i == 44 * stop_at_epoch);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_Inf) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf.";
|
||||
|
||||
// if num_epoch == -1, it means infinity.
|
||||
int32_t num_epoch = -1;
|
||||
|
||||
int32_t num_repeats = 2;
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
int32_t num_repeats_2 = 3;
|
||||
std::shared_ptr<RepeatOp> repeat_op_2;
|
||||
rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2;
|
||||
std::string golden;
|
||||
for (int j = 0; j < num_repeats_2; j++) {
|
||||
for (int i = 0; i < num_repeats; i++) {
|
||||
golden += golden_imgs;
|
||||
}
|
||||
}
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree_);
|
||||
TensorMap tensor_map;
|
||||
uint64_t i = 0;
|
||||
|
||||
// For this test, we stop at stop_at_epoch number.
|
||||
int32_t stop_at_epoch = 2 + std::rand() % 6;
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
|
||||
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_map["image"]->GetBuffer(), (int64_t) tensor_map["image"]->Size());
|
||||
rc = di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(result, golden);
|
||||
result.clear();
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats * num_repeats_2);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_ChildItr) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Epoch_ChildItr.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "num_epoch: " << num_epoch;
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
ChildIterator ci(my_tree_->root().get(), 0, 0);
|
||||
TensorRow tensor_row;
|
||||
uint64_t total_sample = 0;
|
||||
uint64_t i = 0;
|
||||
uint32_t epoch = 0;
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while(!ci.eof_handled()) {
|
||||
i = 0;
|
||||
while (tensor_row.size() != 0) {
|
||||
tensor_row[1]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
|
||||
epoch++;
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
EXPECT_TRUE(result == golden_imgs);
|
||||
result.clear();
|
||||
EXPECT_TRUE(i == 44);
|
||||
total_sample += i;
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
}
|
||||
EXPECT_TRUE(total_sample == 44 * num_epoch);
|
||||
|
||||
// Try to fetch data after last epoch ends.
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(tensor_row.empty());
|
||||
EXPECT_FALSE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_ChildItr) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_ChildItr.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
|
||||
int32_t num_repeats = 2;
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats;
|
||||
std::string golden;
|
||||
for (int i = 0; i < num_repeats; i++) {
|
||||
golden += golden_imgs;
|
||||
}
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
ChildIterator ci(my_tree_->root().get(), 0, 0);
|
||||
TensorRow tensor_row;
|
||||
uint64_t total_sample = 0;
|
||||
uint64_t i = 0;
|
||||
uint32_t epoch = 0;
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while(!ci.eof_handled()) {
|
||||
i = 0;
|
||||
while (tensor_row.size() != 0) {
|
||||
tensor_row[1]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
|
||||
epoch++;
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
EXPECT_TRUE(i == 44 * num_repeats);
|
||||
total_sample += i;
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
}
|
||||
EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats);
|
||||
|
||||
// Try to fetch data after last epoch ends.
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(tensor_row.empty());
|
||||
EXPECT_FALSE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Repeat_Epoch_ChildItr) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Repeat_Epoch_ChildItr.";
|
||||
|
||||
int32_t num_epoch = 2 + std::rand() % 5;
|
||||
|
||||
int32_t num_repeats = 2;
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
int32_t num_repeats_2 = 3;
|
||||
std::shared_ptr<RepeatOp> repeat_op_2;
|
||||
rc = RepeatOp::Builder(num_repeats_2).Build(&repeat_op_2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op, repeat_op_2});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats << ". num_repeat_2: " << num_repeats_2;
|
||||
std::string golden;
|
||||
for (int j = 0; j < num_repeats_2; j++) {
|
||||
for (int i = 0; i < num_repeats; i++) {
|
||||
golden += golden_imgs;
|
||||
}
|
||||
}
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
ChildIterator ci(my_tree_->root().get(), 0, 0);
|
||||
TensorRow tensor_row;
|
||||
uint64_t total_sample = 0;
|
||||
uint64_t i = 0;
|
||||
uint32_t epoch = 0;
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while(!ci.eof_handled()) {
|
||||
i = 0;
|
||||
while (tensor_row.size() != 0) {
|
||||
tensor_row[1]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
|
||||
epoch++;
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
EXPECT_TRUE(i == 44 * num_repeats * num_repeats_2);
|
||||
total_sample += i;
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
}
|
||||
EXPECT_TRUE(total_sample == 44 * num_epoch * num_repeats * num_repeats_2);
|
||||
|
||||
// Try to fetch data after last epoch ends.
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(tensor_row.empty());
|
||||
EXPECT_FALSE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Epoch_Inf_ChildItr) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Epoch_Inf_ChildItr.";
|
||||
|
||||
// if num_epoch == -1, it means infinity.
|
||||
int32_t num_epoch = -1;
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false)});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
ChildIterator ci(my_tree_->root().get(), 0, 0);
|
||||
TensorRow tensor_row;
|
||||
uint64_t i = 0;
|
||||
|
||||
// For this test, we stop at a random number between 0 - 100 epochs.
|
||||
int32_t stop_at_epoch = 2 + std::rand() % 5;
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
|
||||
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_row.size() != 0) {
|
||||
tensor_row[1]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_TRUE(result == golden_imgs);
|
||||
result.clear();
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
EXPECT_TRUE(i == 44 * stop_at_epoch);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrlOp, ImageFolder_Repeat_Epoch_Inf_ChildItr) {
|
||||
MS_LOG(WARNING) << "Doing ImageFolder_Repeat_Epoch_Inf_ChildItr.";
|
||||
|
||||
// if num_epoch == -1, it means infinity.
|
||||
int32_t num_epoch = -1;
|
||||
int32_t num_repeats = 2;
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(2, 2, 32, folder_path, false), repeat_op});
|
||||
rc = my_tree_->Prepare(num_epoch);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". num_repeat: " << num_repeats;
|
||||
std::string golden;
|
||||
for (int i = 0; i < num_repeats; i++) {
|
||||
golden += golden_imgs;
|
||||
}
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
ChildIterator ci(my_tree_->root().get(), 0, 0);
|
||||
TensorRow tensor_row;
|
||||
uint64_t i = 0;
|
||||
|
||||
// For this test, we stop at a random number between 0 - 100 epochs.
|
||||
int32_t stop_at_epoch = 2 + std::rand() % 5;
|
||||
MS_LOG(DEBUG) << "num_epoch: " << num_epoch << ". Stop at epoch: " << stop_at_epoch;
|
||||
for (int epoch = 0; epoch < stop_at_epoch; epoch++) {
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tensor_row.size() != 0) {
|
||||
tensor_row[1]->GetItemAt<int32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row:" << i << "\tlabel:" << label << "\n";
|
||||
EXPECT_TRUE(img_class[(i % 44) / 11] == label);
|
||||
// Dump all the image into string, to be used as a comparison later.
|
||||
result.append((char *) tensor_row[0]->GetBuffer(), (int64_t) tensor_row[0]->Size());
|
||||
rc = ci.FetchNextTensorRow(&tensor_row);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
i++;
|
||||
}
|
||||
EXPECT_TRUE(result == golden);
|
||||
result.clear();
|
||||
MS_LOG(DEBUG) << "Current epoch: " << epoch << ". Sample count: " << i;
|
||||
}
|
||||
EXPECT_TRUE(i == 44 * stop_at_epoch * num_repeats);
|
||||
}
|
|
@ -46,7 +46,8 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
my_tree->AssociateNode(parent_op);
|
||||
rc = my_tree->AssociateNode(parent_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
ASSERT_NE(parent_op, nullptr);
|
||||
ASSERT_NE(my_tfreader_op, nullptr);
|
||||
parent_op->AddChild(std::move(my_tfreader_op));
|
||||
|
|
|
@ -104,9 +104,11 @@ def test_cache_map_basic3():
|
|||
decode_op = c_vision.Decode()
|
||||
ds1 = ds1.repeat(4)
|
||||
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
|
||||
print("ds1.dataset_size is ", ds1.get_dataset_size())
|
||||
|
||||
num_iter = 0
|
||||
for _ in ds1.create_dict_iterator():
|
||||
print("get data from dataset")
|
||||
num_iter += 1
|
||||
|
||||
logger.info("Number of data in ds1: {} ".format(num_iter))
|
||||
|
@ -152,6 +154,10 @@ def test_cache_map_failure1():
|
|||
|
||||
if __name__ == '__main__':
|
||||
test_cache_map_basic1()
|
||||
print("test_cache_map_basic1 success.")
|
||||
test_cache_map_basic2()
|
||||
print("test_cache_map_basic2 success.")
|
||||
test_cache_map_basic3()
|
||||
print("test_cache_map_basic3 success.")
|
||||
test_cache_map_failure1()
|
||||
print("test_cache_map_failure1 success.")
|
||||
|
|
|
@ -238,7 +238,7 @@ def test_tfrecord_shard_equal_rows():
|
|||
def test_tfrecord_no_schema_columns_list():
|
||||
logger.info("test_tfrecord_no_schema_columns_list")
|
||||
data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
|
||||
row = data.create_dict_iterator().get_next()
|
||||
row = data.create_dict_iterator().__next__()
|
||||
assert row["col_sint16"] == [-32768]
|
||||
|
||||
with pytest.raises(KeyError) as info:
|
||||
|
@ -258,7 +258,7 @@ def test_tfrecord_schema_columns_list():
|
|||
schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
|
||||
schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
|
||||
data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
|
||||
row = data.create_dict_iterator().get_next()
|
||||
row = data.create_dict_iterator().__next__()
|
||||
assert row["col_sint16"] == [-32768]
|
||||
|
||||
with pytest.raises(KeyError) as info:
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import time
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
@ -35,6 +37,8 @@ def test_case_0():
|
|||
|
||||
data = data.device_que()
|
||||
data.send()
|
||||
time.sleep(0.1)
|
||||
data.stop_send()
|
||||
|
||||
|
||||
def test_case_1():
|
||||
|
@ -58,6 +62,8 @@ def test_case_1():
|
|||
|
||||
data = data.device_que()
|
||||
data.send()
|
||||
time.sleep(0.1)
|
||||
data.stop_send()
|
||||
|
||||
|
||||
def test_case_2():
|
||||
|
@ -84,6 +90,8 @@ def test_case_2():
|
|||
data = data.device_que()
|
||||
assert data.get_repeat_count() == 2
|
||||
data.send()
|
||||
time.sleep(0.1)
|
||||
data.stop_send()
|
||||
|
||||
|
||||
def test_case_3():
|
||||
|
@ -109,13 +117,17 @@ def test_case_3():
|
|||
|
||||
data = data.device_que()
|
||||
data.send()
|
||||
time.sleep(0.1)
|
||||
data.stop_send()
|
||||
|
||||
|
||||
def test_case_tf_file():
|
||||
data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
|
||||
|
||||
data = data.to_device(num_batch=10)
|
||||
data = data.to_device()
|
||||
data.send()
|
||||
time.sleep(0.1)
|
||||
data.stop_send()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,608 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing Epoch Control op in DE
|
||||
"""
|
||||
import itertools
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def diff_mse(in1, in2):
|
||||
"""
|
||||
diff_mse
|
||||
"""
|
||||
mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
|
||||
return mse * 100
|
||||
|
||||
def test_cifar10():
|
||||
"""
|
||||
dataset parameter
|
||||
"""
|
||||
logger.info("Test dataset parameter")
|
||||
data_dir_10 = "../data/dataset/testCifar10Data"
|
||||
num_repeat = 2
|
||||
batch_size = 32
|
||||
limit_dataset = 100
|
||||
# apply dataset operations
|
||||
data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
|
||||
data1 = data1.repeat(num_repeat)
|
||||
data1 = data1.batch(batch_size, True)
|
||||
num_epoch = 5
|
||||
# iter1 will always assume there is a next epoch and never shutdown.
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
epoch_count = 0
|
||||
sample_count = 0
|
||||
for _ in range(num_epoch):
|
||||
row_count = 0
|
||||
for _ in iter1:
|
||||
# in this example, each dictionary has keys "image" and "label"
|
||||
row_count += 1
|
||||
assert row_count == int(limit_dataset * num_repeat / batch_size)
|
||||
logger.debug("row_count: ", row_count)
|
||||
epoch_count += 1
|
||||
sample_count += row_count
|
||||
assert epoch_count == num_epoch
|
||||
logger.debug("total epochs: ", epoch_count)
|
||||
assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
|
||||
logger.debug("total sample: ", sample_count)
|
||||
|
||||
|
||||
def test_decode_op():
|
||||
"""
|
||||
Test Decode op
|
||||
"""
|
||||
logger.info("test_decode_op")
|
||||
|
||||
# Decode with rgb format set to True
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
|
||||
# Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
|
||||
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
|
||||
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
|
||||
num_epoch = 5
|
||||
# iter1 will always assume there is a next epoch and never shutdown.
|
||||
iter1 = data1.create_dict_iterator()
|
||||
# iter 2 will stop and shutdown pipeline after num_epoch
|
||||
iter2 = data2.create_dict_iterator(num_epoch)
|
||||
for _ in range(num_epoch):
|
||||
i = 0
|
||||
for item1, item2 in itertools.zip_longest(iter1, iter2):
|
||||
actual = item1["image"]
|
||||
expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
|
||||
expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
|
||||
assert actual.shape == expected.shape
|
||||
diff = actual - expected
|
||||
mse = np.sum(np.power(diff, 2))
|
||||
assert mse == 0
|
||||
i = i + 1
|
||||
assert i == 3
|
||||
|
||||
# Users have the option to manually stop the iterator, or rely on garbage collector.
|
||||
iter1.stop()
|
||||
# Expect a AttributeError since iter1 has been stopped.
|
||||
with pytest.raises(AttributeError) as info:
|
||||
iter1.__next__()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter2.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
# Generate 1d int numpy array from 0 - 63
|
||||
def generator_1d():
|
||||
"""
|
||||
generator
|
||||
"""
|
||||
for i in range(64):
|
||||
yield (np.array([i]),)
|
||||
|
||||
|
||||
def test_generator_dict_0():
|
||||
"""
|
||||
test generator dict 0
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
|
||||
i = 0
|
||||
# create the iterator inside the loop declaration
|
||||
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
|
||||
def test_generator_dict_1():
|
||||
"""
|
||||
test generator dict 1
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
# BAD. Do not create iterator every time inside.
|
||||
# Create iterator outside the epoch for loop.
|
||||
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
def test_generator_dict_2():
|
||||
"""
|
||||
test generator dict 2
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_dict_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
# iter1 is still alive and running.
|
||||
item1 = iter1.__next__()
|
||||
assert item1
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
def test_generator_dict_3():
|
||||
"""
|
||||
test generator dict 3
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_dict_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
# optional
|
||||
iter1.stop()
|
||||
# Expect a AttributeError since iter1 has been stopped.
|
||||
with pytest.raises(AttributeError) as info:
|
||||
iter1.__next__()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
|
||||
def test_generator_dict_4():
|
||||
"""
|
||||
test generator dict 4
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_dict_iterator(num_epochs=10)
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
def test_generator_dict_4_1():
|
||||
"""
|
||||
test generator dict 4_1
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
# epoch ctrl op will not be injected if num_epochs is 1.
|
||||
iter1 = data1.create_dict_iterator(num_epochs=1)
|
||||
for _ in range(1):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
def test_generator_dict_4_2():
|
||||
"""
|
||||
test generator dict 4_2
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
# repeat will not be injected when num repeat is 1.
|
||||
data1 = data1.repeat(1)
|
||||
# epoch ctrl op will not be injected if num_epochs is 1.
|
||||
iter1 = data1.create_dict_iterator(num_epochs=1)
|
||||
for _ in range(1):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
def test_generator_dict_5():
|
||||
"""
|
||||
test generator dict 5
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_dict_iterator(num_epochs=11)
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
# still one more epoch left in the iter1.
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
# now iter1 has been exhausted, c++ pipeline has been shut down.
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
# Test tuple iterator
|
||||
|
||||
def test_generator_tuple_0():
|
||||
"""
|
||||
test generator tuple 0
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
|
||||
i = 0
|
||||
# create the iterator inside the loop declaration
|
||||
for item in data1.create_tuple_iterator(): # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
|
||||
def test_generator_tuple_1():
|
||||
"""
|
||||
test generator tuple 1
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
# BAD. Do not create iterator every time inside.
|
||||
# Create iterator outside the epoch for loop.
|
||||
for item in data1.create_tuple_iterator(): # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
def test_generator_tuple_2():
|
||||
"""
|
||||
test generator tuple 2
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
# iter1 is still alive and running.
|
||||
item1 = iter1.__next__()
|
||||
assert item1
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
def test_generator_tuple_3():
|
||||
"""
|
||||
test generator tuple 3
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
# optional
|
||||
iter1.stop()
|
||||
# Expect a AttributeError since iter1 has been stopped.
|
||||
with pytest.raises(AttributeError) as info:
|
||||
iter1.__next__()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
|
||||
def test_generator_tuple_4():
|
||||
"""
|
||||
test generator tuple 4
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=10)
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
def test_generator_tuple_5():
|
||||
"""
|
||||
test generator tuple 5
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=11)
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
# still one more epoch left in the iter1.
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64
|
||||
|
||||
# now iter1 has been exhausted, c++ pipeline has been shut down.
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
# Test with repeat
|
||||
def test_generator_tuple_repeat_1():
|
||||
"""
|
||||
test generator tuple repeat 1
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat(2)
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=11)
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2
|
||||
|
||||
# still one more epoch left in the iter1.
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2
|
||||
|
||||
# now iter1 has been exhausted, c++ pipeline has been shut down.
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
# Test with repeat
|
||||
def test_generator_tuple_repeat_repeat_1():
|
||||
"""
|
||||
test generator tuple repeat repeat 1
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat(2)
|
||||
data1 = data1.repeat(3)
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=11)
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2 * 3
|
||||
|
||||
# still one more epoch left in the iter1.
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2 * 3
|
||||
|
||||
# now iter1 has been exhausted, c++ pipeline has been shut down.
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
iter1.__next__()
|
||||
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
|
||||
assert err_msg in str(info.value)
|
||||
|
||||
|
||||
def test_generator_tuple_repeat_repeat_2():
|
||||
"""
|
||||
test generator tuple repeat repeat 2
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat(2)
|
||||
data1 = data1.repeat(3)
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2 * 3
|
||||
# optional
|
||||
iter1.stop()
|
||||
# Expect a AttributeError since iter1 has been stopped.
|
||||
with pytest.raises(AttributeError) as info:
|
||||
iter1.__next__()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
def test_generator_tuple_repeat_repeat_3():
|
||||
"""
|
||||
test generator tuple repeat repeat 3
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat(2)
|
||||
data1 = data1.repeat(3)
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2 * 3
|
||||
|
||||
for _ in range(5):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2 * 3
|
||||
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
def test_generator_reusedataset():
|
||||
"""
|
||||
test generator reusedataset
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat(2)
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
for _ in range(10):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2
|
||||
|
||||
data1 = data1.repeat(3)
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
for _ in range(5):
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
assert np.array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
assert i == 64 * 2 * 3
|
||||
|
||||
data1 = data1.batch(2)
|
||||
iter1 = data1.create_dict_iterator()
|
||||
for _ in range(5):
|
||||
i = 0
|
||||
sample = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([[i % 64], [(i + 1) % 64]])
|
||||
assert np.array_equal(item["data"], golden)
|
||||
i = i + 2
|
||||
sample = sample + 1
|
||||
assert sample == 64 * 3
|
||||
|
||||
# rely on garbage collector to destroy iter1
|
|
@ -87,7 +87,7 @@ def test_five_crop_error_msg():
|
|||
data = data.map(input_columns=["image"], operations=transform())
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
data.create_tuple_iterator().get_next()
|
||||
data.create_tuple_iterator().__next__()
|
||||
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
|
||||
|
||||
# error msg comes from ToTensor()
|
||||
|
|
|
@ -41,18 +41,18 @@ def test_case1():
|
|||
assert data.get_batch_size() == 2
|
||||
assert data.get_repeat_count() == 1
|
||||
data = data.repeat(10)
|
||||
assert data.get_dataset_size() == 6
|
||||
assert data.get_dataset_size() == 60
|
||||
assert data.get_batch_size() == 2
|
||||
assert data.get_repeat_count() == 10
|
||||
data = data.project(["new_column"])
|
||||
assert data.get_dataset_size() == 6
|
||||
assert data.get_dataset_size() == 60
|
||||
assert data.get_batch_size() == 2
|
||||
assert data.get_repeat_count() == 10
|
||||
|
||||
data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
|
||||
|
||||
data1 = data.zip(data2)
|
||||
assert data1.get_dataset_size() == 6
|
||||
assert data1.get_dataset_size() == 60
|
||||
|
||||
|
||||
def test_case2():
|
||||
|
@ -65,14 +65,14 @@ def test_case2():
|
|||
data = data.rename("col_sint64", "new_column")
|
||||
assert data.get_dataset_size() == 3
|
||||
data = data.repeat(10)
|
||||
assert data.get_dataset_size() == 3
|
||||
assert data.get_dataset_size() == 30
|
||||
data = data.project(["new_column"])
|
||||
assert data.get_dataset_size() == 3
|
||||
assert data.get_dataset_size() == 30
|
||||
|
||||
data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10)
|
||||
|
||||
data1 = data.zip(data2)
|
||||
assert data1.get_dataset_size() == 3
|
||||
assert data1.get_dataset_size() == 30
|
||||
|
||||
|
||||
def test_case3():
|
||||
|
@ -94,11 +94,11 @@ def test_case4():
|
|||
data2 = data2.shuffle(100)
|
||||
assert data2.get_dataset_size() == 6
|
||||
data2 = data2.repeat(3)
|
||||
assert data2.get_dataset_size() == 6
|
||||
assert data2.get_dataset_size() == 18
|
||||
|
||||
data3 = ds.zip((data1, data2))
|
||||
|
||||
assert data3.get_dataset_size() == 6
|
||||
assert data3.get_dataset_size() == 18
|
||||
|
||||
|
||||
def test_case5():
|
||||
|
|
|
@ -73,7 +73,7 @@ def test_iterator_weak_ref():
|
|||
|
||||
_cleanup()
|
||||
with pytest.raises(AttributeError) as info:
|
||||
itr2.get_next()
|
||||
itr2.__next__()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
del itr1
|
||||
|
|
|
@ -251,6 +251,49 @@ def test_nested_repeat11():
|
|||
|
||||
assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
|
||||
|
||||
def test_repeat_count1():
|
||||
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is {}".format(data1_size))
|
||||
batch_size = 2
|
||||
repeat_count = 4
|
||||
resize_height, resize_width = 32, 32
|
||||
decode_op = vision.Decode()
|
||||
resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=resize_op)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
data1 = data1.batch(batch_size, drop_remainder=False)
|
||||
dataset_size = data1.get_dataset_size()
|
||||
logger.info("dataset repeat then batch's size is {}".format(dataset_size))
|
||||
num1_iter = 0
|
||||
for _ in data1.create_dict_iterator():
|
||||
num1_iter += 1
|
||||
|
||||
assert data1_size == 3
|
||||
assert dataset_size == num1_iter == 6
|
||||
|
||||
def test_repeat_count2():
|
||||
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is {}".format(data1_size))
|
||||
batch_size = 2
|
||||
repeat_count = 4
|
||||
resize_height, resize_width = 32, 32
|
||||
decode_op = vision.Decode()
|
||||
resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=resize_op)
|
||||
data1 = data1.batch(batch_size, drop_remainder=False)
|
||||
data1 = data1.repeat(repeat_count)
|
||||
dataset_size = data1.get_dataset_size()
|
||||
logger.info("dataset batch then repeat's size is {}".format(dataset_size))
|
||||
num1_iter = 0
|
||||
for _ in data1.create_dict_iterator():
|
||||
num1_iter += 1
|
||||
|
||||
assert data1_size == 3
|
||||
assert dataset_size == num1_iter == 8
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tf_repeat_01()
|
||||
|
@ -268,3 +311,5 @@ if __name__ == "__main__":
|
|||
test_nested_repeat9()
|
||||
test_nested_repeat10()
|
||||
test_nested_repeat11()
|
||||
test_repeat_count1()
|
||||
test_repeat_count2()
|
||||
|
|
|
@ -252,14 +252,14 @@ def test_zip_exception_06():
|
|||
|
||||
if __name__ == '__main__':
|
||||
test_zip_01()
|
||||
test_zip_02()
|
||||
test_zip_03()
|
||||
test_zip_04()
|
||||
test_zip_05()
|
||||
test_zip_06()
|
||||
test_zip_exception_01()
|
||||
test_zip_exception_02()
|
||||
test_zip_exception_03()
|
||||
test_zip_exception_04()
|
||||
test_zip_exception_05()
|
||||
test_zip_exception_06()
|
||||
#test_zip_02()
|
||||
#test_zip_03()
|
||||
#test_zip_04()
|
||||
#test_zip_05()
|
||||
#test_zip_06()
|
||||
#test_zip_exception_01()
|
||||
#test_zip_exception_02()
|
||||
#test_zip_exception_03()
|
||||
#test_zip_exception_04()
|
||||
#test_zip_exception_05()
|
||||
#test_zip_exception_06()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -274,6 +274,9 @@ class DatasetLenet():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self):
|
||||
return self
|
||||
|
||||
|
||||
def test_train_32k_8p(batch_size=32, num_classes=32768):
|
||||
dev_num = 8
|
||||
|
|
|
@ -61,6 +61,9 @@ class DatasetLenet():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self):
|
||||
return self
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
|
|
@ -58,6 +58,9 @@ class Dataset():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self):
|
||||
return self
|
||||
|
||||
|
||||
class GatherV2(_Loss):
|
||||
def __init__(self, index_dim, strategy, index_size=16):
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test dataset helper."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.dataset_helper import DatasetHelper
|
||||
from ....dataset_mock import MindData
|
||||
|
||||
|
||||
def get_dataset(batch_size=1):
|
||||
dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
|
||||
dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
|
||||
(batch_size, 20), (batch_size, 20), (batch_size, 20))
|
||||
|
||||
dataset = MindData(size=2, batch_size=batch_size, np_types=dataset_types,
|
||||
output_shapes=dataset_shapes, input_indexs=(0, 1))
|
||||
return dataset
|
||||
|
||||
|
||||
def test_dataset_helper_dataset_sink_mode_str():
|
||||
dataset = get_dataset(32)
|
||||
with pytest.raises(TypeError):
|
||||
DatasetHelper(dataset, dataset_sink_mode="True")
|
||||
|
||||
|
||||
def test_dataset_helper_dataset_sink_mode_int():
|
||||
dataset = get_dataset(32)
|
||||
with pytest.raises(TypeError):
|
||||
DatasetHelper(dataset, dataset_sink_mode=1)
|
||||
|
||||
|
||||
def test_dataset_helper_sink_size_bool():
|
||||
dataset = get_dataset(32)
|
||||
with pytest.raises(TypeError):
|
||||
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=True)
|
||||
|
||||
|
||||
def test_dataset_helper_sink_size_float():
|
||||
dataset = get_dataset(32)
|
||||
with pytest.raises(TypeError):
|
||||
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=1.0)
|
||||
|
||||
|
||||
def test_dataset_helper_sink_size_negative():
|
||||
dataset = get_dataset(32)
|
||||
with pytest.raises(ValueError):
|
||||
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=-2)
|
||||
|
||||
|
||||
def test_dataset_iter_normal():
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=False)
|
||||
count = 0
|
||||
for _ in range(2):
|
||||
for _ in dataset_helper:
|
||||
count += 1
|
||||
dataset.reset()
|
||||
assert count == 6
|
||||
|
||||
|
||||
@pytest.mark.skipif('not context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ge():
|
||||
init()
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
count = 0
|
||||
for _ in range(2):
|
||||
for _ in dataset_helper:
|
||||
count += 1
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif('context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ms_loop_sink():
|
||||
init()
|
||||
context.set_context(enable_loop_sink=True)
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
count = 0
|
||||
for _ in range(2):
|
||||
for inputs in dataset_helper:
|
||||
count += 1
|
||||
assert inputs == tuple()
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif('context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ms():
|
||||
init()
|
||||
context.set_context(enable_loop_sink=False)
|
||||
dataset = get_dataset(32)
|
||||
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
Loading…
Reference in New Issue