forked from mindspore-Ecosystem/mindspore
Add temp fix for generator Op when num_epochs=-1
This commit is contained in:
parent
a75665ce49
commit
bcf913fb0d
|
@ -51,7 +51,15 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_;
|
||||
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_
|
||||
<< "\nLeaf Nodes in execution path:";
|
||||
if (!eoe_ops_.empty()) {
|
||||
for (size_t i = 0; i < eoe_ops_.size(); i++) {
|
||||
out << "\n Operator: " << eoe_ops_[i]->id();
|
||||
}
|
||||
} else {
|
||||
out << " None.";
|
||||
}
|
||||
out << "\n\n";
|
||||
}
|
||||
}
|
||||
|
@ -86,6 +94,13 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
|
|||
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
|
||||
state_ = OpState::kDeOpIdle;
|
||||
|
||||
if (repeat_count_ != num_repeats_) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -62,7 +62,15 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_;
|
||||
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_
|
||||
<< "\nLeaf Nodes in execution path:";
|
||||
if (!eoe_ops_.empty()) {
|
||||
for (size_t i = 0; i < eoe_ops_.size(); i++) {
|
||||
out << "\n Operator: " << eoe_ops_[i]->id();
|
||||
}
|
||||
} else {
|
||||
out << " None.";
|
||||
}
|
||||
out << "\n\n";
|
||||
}
|
||||
}
|
||||
|
@ -107,9 +115,17 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
|
|||
if (repeat_count_ == num_repeats_) {
|
||||
repeat_count_ = 0;
|
||||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
} else {
|
||||
state_ = OpState::kDeOpRunning;
|
||||
}
|
||||
|
||||
// Invoke a reset against the eoe nodes only.
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -138,6 +154,19 @@ int32_t RepeatOp::num_consumers() const {
|
|||
}
|
||||
}
|
||||
|
||||
// Drive reset actions if needed
|
||||
Status RepeatOp::Reset() {
|
||||
// If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
|
||||
// In that case, we now have to bounce the reset down to our own eoe ops.
|
||||
MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
}
|
||||
state_ = OpState::kDeOpRunning;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t RepeatOp::num_producers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
|
||||
|
|
|
@ -129,6 +129,16 @@ class RepeatOp : public PipelineOp {
|
|||
/// \return The number of repeats that the user requested
|
||||
int32_t num_repeats() { return num_repeats_; }
|
||||
|
||||
/// \brief reset Op
|
||||
/// \return Status - The error code return
|
||||
Status Reset() override;
|
||||
|
||||
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||
// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
|
||||
|
||||
protected:
|
||||
// The number of repeats that the user requested.
|
||||
// Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class.
|
||||
|
|
|
@ -186,6 +186,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) {
|
|||
Status GeneratorOp::operator()() {
|
||||
// Handshake with TaskManager to synchronize thread creation
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
|
||||
std::unique_ptr<DataBuffer> fetched_buffer;
|
||||
bool eof = false;
|
||||
while (!eof) {
|
||||
|
@ -227,8 +228,17 @@ Status GeneratorOp::operator()() {
|
|||
MS_LOG(DEBUG) << "Generator operator main execution loop complete.";
|
||||
eof = true;
|
||||
} else {
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
// Waiting for repeatOp to start new epoch
|
||||
// If Reset() is called first by repeat op, this wait() will return right away.
|
||||
// If Reset() is not called yet, this wait() will block until reset.
|
||||
if (this->op_total_repeats() < 0) {
|
||||
RETURN_IF_NOT_OK(wp_.Wait());
|
||||
// Clear the status of the wait post
|
||||
wp_.Clear();
|
||||
} else {
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
}
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
|
@ -240,6 +250,10 @@ Status GeneratorOp::Reset() {
|
|||
// Reset Op state
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
RETURN_IF_NOT_OK(this->Init());
|
||||
if (this->op_total_repeats() < 0) {
|
||||
// Wake up master thread
|
||||
wp_.Set();
|
||||
}
|
||||
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
||||
}
|
||||
|
||||
|
|
|
@ -144,6 +144,8 @@ class GeneratorOp : public PipelineOp {
|
|||
py::object generator_;
|
||||
int32_t buffer_id_;
|
||||
|
||||
WaitPost wp_;
|
||||
|
||||
Status Init();
|
||||
|
||||
void Dealloc() noexcept;
|
||||
|
|
|
@ -22,15 +22,31 @@
|
|||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RepeatPass::RepeatPass()
|
||||
: num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}
|
||||
: is_repeated_(false),
|
||||
nested_repeats_(0),
|
||||
num_repeats_(1),
|
||||
num_epochs_(1),
|
||||
is_merge_(false),
|
||||
is_cached_(false),
|
||||
cache_lookup_(nullptr) {}
|
||||
|
||||
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Create a new stack for eoe operators and push onto our stack of stacks.
|
||||
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
|
||||
eoe_op_stacks_.push(std::move(new_stack));
|
||||
// If we are already repeated, then this is a nested repeat.
|
||||
if (is_repeated_) {
|
||||
nested_repeats_++;
|
||||
}
|
||||
is_repeated_ = true;
|
||||
|
||||
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
|
||||
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
|
||||
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
|
||||
|
@ -58,7 +74,9 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie
|
|||
// that RepeatOp does. However, epoch control is actually simpler because it can
|
||||
// only exist as the root node so it doesn't need all the nested code.
|
||||
// Create a new stack for eoe operators and push onto our stack of stacks.
|
||||
|
||||
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
|
||||
eoe_op_stacks_.push(std::move(new_stack));
|
||||
is_repeated_ = true;
|
||||
// Get the total number of epochs from the EpochCtrlOp parameter
|
||||
num_epochs_ = node->num_repeats();
|
||||
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
|
||||
|
@ -85,6 +103,22 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
|||
|
||||
// Hooks up any identified eoe nodes under this repeat.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
|
||||
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||
|
||||
while (leaf_op != nullptr) {
|
||||
node->AddToEoeList(leaf_op);
|
||||
leaf_op = PopFromEOEOpStack();
|
||||
}
|
||||
|
||||
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
|
||||
// at this time, so we can pop it to get rid of it.
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
if (!current_stack->empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
|
||||
}
|
||||
eoe_op_stacks_.pop();
|
||||
|
||||
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
|
||||
// and set its total repeats. It is important that the op is removed from the save area,
|
||||
// because the merge op above us may also take action on it later for a different case when
|
||||
|
@ -95,6 +129,18 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
cache_lookup_.reset();
|
||||
}
|
||||
|
||||
// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
|
||||
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
|
||||
if (nested_repeats_ > 0) {
|
||||
AddToEOEOpStack(node);
|
||||
nested_repeats_--;
|
||||
} else {
|
||||
// If we are not nested, or we were the top-most repeat, now we clear the flag
|
||||
if (nested_repeats_ != 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
|
||||
}
|
||||
is_repeated_ = false;
|
||||
}
|
||||
if (is_cached_) {
|
||||
AddToCachedOpStack(node);
|
||||
}
|
||||
|
@ -110,6 +156,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
|
||||
// Hooks up any identified eoe nodes under this repeat.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// Pop the leaf ops from the save-area stack and add them to the eoe node tracking
|
||||
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||
while (leaf_op != nullptr) {
|
||||
node->AddToEoeList(leaf_op);
|
||||
leaf_op = PopFromEOEOpStack();
|
||||
}
|
||||
is_repeated_ = false;
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
// We finish the walk of this EpochCtrl's descendent nodes.
|
||||
|
@ -138,6 +191,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Special-case visit for GeneratorOp during the repeat pass.
// The node is saved on the eoe stack (repeated path with a negative repeat count only), saved on
// the cached op stack (when under a cache op), and then has its repeat counters assigned.
Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
  // Only a generator in a repeated path whose accumulated repeat count is negative (infinite)
  // is saved for the repeat operator above it.
  const bool track_for_repeat = is_repeated_ && num_repeats_ < 0;
  if (track_for_repeat) {
    AddToEOEOpStack(node);
  }

  // A generator underneath a cache op is also saved on the cached op stack.
  if (is_cached_) {
    AddToCachedOpStack(node);
  }

  // Record the total repeats and the repeats per epoch on the node.
  node->set_total_repeats(num_repeats_);
  node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);

  return Status::OK();
}
|
||||
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
// for use with a controlling repeat above it.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||
|
@ -190,6 +260,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds an operator to the eoe operator stack save area
|
||||
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
current_stack->push(dataset_op);
|
||||
}
|
||||
|
||||
// Pops an operator from the eoe operator stack save area
|
||||
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
|
||||
std::shared_ptr<DatasetOp> top_op = nullptr;
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
if (current_stack != nullptr && !current_stack->empty()) {
|
||||
top_op = current_stack->top();
|
||||
current_stack->pop();
|
||||
}
|
||||
return top_op;
|
||||
}
|
||||
|
||||
// Adds an operator to the cached operator stack save area
|
||||
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }
|
||||
|
||||
|
|
|
@ -98,6 +98,12 @@ class RepeatPass : public NodePass {
|
|||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Special case for GeneratorOp
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
/// for use with a controlling repeat above it.
|
||||
/// \param[in] node The node being visited
|
||||
|
@ -106,6 +112,19 @@ class RepeatPass : public NodePass {
|
|||
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
/// \brief Adds an operator to the eoe operator stack save area
|
||||
/// \param[in] dataset_op The dataset op to add to the eoe stack
|
||||
/// \return void (no status is reported)
|
||||
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
/// \brief Pops an operator from the eoe operator stack save area
|
||||
/// \return shared_ptr to the popped operator
|
||||
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
|
||||
|
||||
bool is_repeated_; // T/F if we are processing under a repeat
|
||||
int32_t nested_repeats_; // A counter for nested repeats
|
||||
std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)
|
||||
|
||||
/// \brief Adds an operator to the cached operator stack save area
|
||||
/// \param op - The dataset op to work add to cached stack
|
||||
/// \return Status - The error code return
|
||||
|
|
Loading…
Reference in New Issue