forked from mindspore-Ecosystem/mindspore
!3346 Maintain epoch/repeat count for ops
Merge pull request !3346 from lixiachen/repeat_rework
This commit is contained in:
commit
e19d382473
|
@ -89,13 +89,14 @@ Status CacheBase::FetchSamplesToWorkers() {
|
|||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
// If this is not the last repeat, wait for reset.
|
||||
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (!IsLastIteration()) {
|
||||
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
|
||||
RETURN_IF_NOT_OK(epoch_sync_.Wait());
|
||||
} else {
|
||||
// We can break out from the loop.
|
||||
break;
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
} while (true);
|
||||
// Flow the eof before exit
|
||||
RETURN_IF_NOT_OK(
|
||||
|
|
|
@ -292,7 +292,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
|
|||
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
|
||||
// If we are in a repeat path, send the eoe up.
|
||||
// Otherwise ignore it.
|
||||
if (BitTest(op_ctrl_flags_, kDeOpRepeated)) {
|
||||
if (op_total_repeats_ > 1) {
|
||||
return DatasetOp::EoeReceived(worker_id);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -304,7 +304,7 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) {
|
|||
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
|
||||
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
|
||||
// provides that for us.
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
|
||||
if (op_total_repeats_ == 1) {
|
||||
MS_LOG(DEBUG) << "Cache merge sending eoe";
|
||||
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
|
||||
}
|
||||
|
|
|
@ -85,6 +85,10 @@ Status CacheOp::operator()() {
|
|||
TaskManager::FindMe()->Post();
|
||||
// Wait for the workers to finish caching the rows.
|
||||
RETURN_IF_NOT_OK(WaitForCachingAllRows());
|
||||
// Current repeats and current epochs may have increased when caching all rows with DatasetOp::GetNextInput.
|
||||
// But they shouldn't be increased because now cache op is starting to act as a leaf and its epoch hasn't started.
|
||||
op_current_repeats_ = 0;
|
||||
op_current_epochs_ = 0;
|
||||
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -85,6 +85,7 @@ Status ConcatOp::operator()() {
|
|||
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
|
||||
"Something went wrong, eof count does not match the number of children.");
|
||||
|
|
|
@ -42,7 +42,10 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
|
|||
operator_id_(kInvalidOperatorId),
|
||||
tree_(nullptr),
|
||||
state_(OpState::kDeOpIdle),
|
||||
op_ctrl_flags_(kDeOpNone),
|
||||
op_total_repeats_(kInfiniteRepeat),
|
||||
op_num_repeats_per_epoch_(kInfiniteRepeat),
|
||||
op_current_repeats_(0),
|
||||
op_current_epochs_(0),
|
||||
out_connector_(nullptr) {
|
||||
// The operator starts out with an invalid operator id. The only way to
|
||||
// get it out of invalid state is to assign the operator to an execution tree.
|
||||
|
@ -237,8 +240,8 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
|
|||
for (size_t i = 0; i < parent_.size(); i++) {
|
||||
out << "\n Parent[" << i << "] id: " << parent_[i]->id();
|
||||
}
|
||||
out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
|
||||
<< std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
|
||||
out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_
|
||||
<< "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_;
|
||||
if (sampler_) {
|
||||
sampler_->Print(out, show_all);
|
||||
}
|
||||
|
@ -265,6 +268,7 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
|
|||
RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
|
||||
// Loop until non EOE is received
|
||||
while (buf->eoe()) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
if (state_ == OpState::kDeOpIdle) {
|
||||
*p_buffer = std::move(buf);
|
||||
|
@ -408,5 +412,10 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
|
|||
uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
|
||||
return cache_crc;
|
||||
}
|
||||
|
||||
void DatasetOp::UpdateRepeatAndEpochCounter() {
|
||||
op_current_repeats_++;
|
||||
if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -70,13 +70,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
|
||||
public:
|
||||
static constexpr int32_t kInvalidOperatorId = -1;
|
||||
|
||||
// Operator control flags
|
||||
enum OpControlFlags {
|
||||
kDeOpNone = 0,
|
||||
kDeOpRepeated = 1, // Operator is a node in a repeat path
|
||||
kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
|
||||
};
|
||||
static constexpr int32_t kInfiniteRepeat = -1;
|
||||
|
||||
// Flags that control operator runtime behaviours
|
||||
enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };
|
||||
|
@ -238,13 +232,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return T/F if this is an inlined operator
|
||||
bool inlined() const { return (oc_queue_size_ == 0); }
|
||||
|
||||
/// \brief Setter function
|
||||
/// \return Sets the control flags
|
||||
void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }
|
||||
/// \brief Setter function, set the number of total repeats for the operator
|
||||
void set_total_repeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }
|
||||
|
||||
/// \brief Setter function
|
||||
/// \return Sets the control flags
|
||||
void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
|
||||
/// \brief Setter function, set the number of repeats per epoch for the operator
|
||||
void set_num_repeats_per_epoch(int32_t num_repeats_per_epoch) { op_num_repeats_per_epoch_ = num_repeats_per_epoch; }
|
||||
|
||||
/// \brief Getter function
|
||||
/// \return The number of required repeats for the operator
|
||||
int32_t op_total_repeats() { return op_total_repeats_; }
|
||||
|
||||
/// \brief Getter function
|
||||
/// \return The number of required epochs for the operator
|
||||
int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; }
|
||||
|
||||
/// \brief Getter function
|
||||
/// \return The number of repeats per epoch for the operator
|
||||
int32_t op_num_repeats_per_epoch() { return op_num_repeats_per_epoch_; }
|
||||
|
||||
/// \brief Register the internal worker connectors. No op unless it is a parallel op
|
||||
/// \return Status
|
||||
|
@ -350,6 +354,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return boolean returns true if it's a leaf
|
||||
bool IsLeaf() { return (child_.empty()); }
|
||||
|
||||
/// Checks if an operator has reached its last iteration
|
||||
/// \return boolean returns true if it's last iteration
|
||||
bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
|
||||
|
||||
protected:
|
||||
/// \brief Removes a parent operator from this operator
|
||||
/// \notes External callers do not have access to this function
|
||||
|
@ -368,6 +376,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return - Status
|
||||
virtual Status ComputeColMap();
|
||||
|
||||
/// Increase op_current_repeats_ by 1 when one repeat finished.
|
||||
/// If this repeat happens to be the last repeat in the current epoch, also increase op_current_epochs_ by 1.
|
||||
void UpdateRepeatAndEpochCounter();
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
|
||||
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
|
||||
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
|
||||
|
@ -375,7 +387,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
int32_t operator_id_; // Generated id for the node
|
||||
ExecutionTree *tree_; // Back pointer to our tree.
|
||||
OpState state_; // The state of the operator, Running, Idle, Terminated
|
||||
uint32_t op_ctrl_flags_; // Flags for the operator
|
||||
int32_t op_total_repeats_; // Required number of repeats for the operator
|
||||
int32_t op_num_repeats_per_epoch_; // Total number of repeats per epoch for the operator
|
||||
int32_t op_current_repeats_; // Current number of repeats the operator has handled
|
||||
int32_t op_current_epochs_; // Current number of epochs the operator has handled
|
||||
std::unique_ptr<DbConnector> out_connector_; // Output Connector
|
||||
std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
|
||||
std::mutex column_name_map_mutex_; // For protecting shared access to the column map
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace dataset {
|
|||
// The builder "build" method creates the final object.
|
||||
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
|
||||
*ptr = std::make_shared<EpochCtrlOp>(build_num_repeats_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -46,12 +46,12 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common 1-liner info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << " [epochs: " << max_repeats_ << "]\n";
|
||||
out << " [epochs: " << num_repeats_ << "]\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
|
||||
out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_
|
||||
<< "\nLeaf Nodes in execution path:";
|
||||
if (!eoe_ops_.empty()) {
|
||||
for (size_t i = 0; i < eoe_ops_.size(); i++) {
|
||||
|
@ -86,24 +86,15 @@ Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t
|
|||
}
|
||||
|
||||
Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
repeat_count_++;
|
||||
MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
|
||||
<< ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;
|
||||
|
||||
// If we've reached the requested epoch count, then flag the leaf nodes
|
||||
// to tell them they've got one more epoch to perform. When they reach the end
|
||||
// of the last epoch, they quit rather than loop again.
|
||||
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
|
||||
eoe_op->set_control_flag(kDeOpLastRepeat);
|
||||
}
|
||||
}
|
||||
<< ". Max epochs: " << num_repeats_;
|
||||
|
||||
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
|
||||
state_ = OpState::kDeOpIdle;
|
||||
|
||||
if (repeat_count_ != max_repeats_) {
|
||||
if (repeat_count_ != num_repeats_) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
|
||||
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||
|
|
|
@ -117,6 +117,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
|
|||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
|
||||
if (in_buffer->eoe()) {
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
|
||||
UpdateRepeatAndEpochCounter();
|
||||
continue;
|
||||
} else if (in_buffer->eof()) {
|
||||
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
|
||||
|
|
|
@ -231,6 +231,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
|
|||
// Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
|
||||
// with Performance Mode design.
|
||||
if (in_buffer->eoe()) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
// Calling base class EoeReceived to forward eoe buffer.
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
// Fetch next data buffer and map job list
|
||||
|
|
|
@ -74,6 +74,9 @@ Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t w
|
|||
if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) {
|
||||
RETURN_IF_NOT_OK(Project(p_buffer));
|
||||
}
|
||||
if ((*p_buffer)->eoe()) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -28,10 +28,10 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {}
|
||||
RepeatOp::Builder::Builder(int32_t count) : build_num_repeats_(count) {}
|
||||
|
||||
Status RepeatOp::Builder::SanityCheck() const {
|
||||
if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) {
|
||||
if (build_num_repeats_ < kInfiniteRepeat || build_num_repeats_ == 0) {
|
||||
std::string err_msg("Repeat count must be > 0 or -1.");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
@ -41,12 +41,12 @@ Status RepeatOp::Builder::SanityCheck() const {
|
|||
// The builder "build" method creates the final object.
|
||||
Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<RepeatOp>(build_max_repeats_);
|
||||
*ptr = std::make_shared<RepeatOp>(build_num_repeats_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of the RepeatOp.
|
||||
RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {}
|
||||
RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), num_repeats_(count), repeat_count_(0) {}
|
||||
|
||||
// Destructor
|
||||
RepeatOp::~RepeatOp() {}
|
||||
|
@ -57,12 +57,12 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common 1-liner info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << " [repeats: " << max_repeats_ << "]\n";
|
||||
out << " [repeats: " << num_repeats_ << "]\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
PipelineOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
|
||||
out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_
|
||||
<< "\nLeaf Nodes in execution path:";
|
||||
if (!eoe_ops_.empty()) {
|
||||
for (size_t i = 0; i < eoe_ops_.size(); i++) {
|
||||
|
@ -107,22 +107,13 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
|
|||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
Status RepeatOp::EoeReceived(int32_t worker_id) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
|
||||
repeat_count_++;
|
||||
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
|
||||
<< ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
|
||||
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
|
||||
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
|
||||
// If we've reached the requested repeat count, then flag the eoe nodes
|
||||
// to tell them they've got one more epoch to perform. When they reach the end
|
||||
// of the last epoch, they quit rather than loop again. This happens in two cases:
|
||||
// 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
|
||||
// 2- We are not repeated
|
||||
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
eoe_op->set_control_flag(kDeOpLastRepeat);
|
||||
}
|
||||
}
|
||||
if (repeat_count_ == max_repeats_) {
|
||||
|
||||
if (repeat_count_ == num_repeats_) {
|
||||
repeat_count_ = 0;
|
||||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
|
|
|
@ -26,8 +26,6 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class RepeatOp : public PipelineOp {
|
||||
public:
|
||||
static constexpr int32_t kInfiniteRepeat = -1;
|
||||
|
||||
// The nested builder class inside of the RepeatOp is used to help manage all of the arguments
|
||||
// for constructing it. This repeat op is very simple though, so this builder is really just
|
||||
// provided for a consistent look and feel for creators of Dataset operators overall.
|
||||
|
@ -47,7 +45,7 @@ class RepeatOp : public PipelineOp {
|
|||
Status Build(std::shared_ptr<RepeatOp> *);
|
||||
|
||||
protected:
|
||||
int32_t build_max_repeats_;
|
||||
int32_t build_num_repeats_;
|
||||
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
@ -131,13 +129,24 @@ class RepeatOp : public PipelineOp {
|
|||
// @return Name of the current Op
|
||||
std::string Name() const override { return kRepeatOp; }
|
||||
|
||||
/// \brief Getter function
|
||||
/// \return The number of repeats that the user requested
|
||||
int32_t num_repeats() { return num_repeats_; }
|
||||
|
||||
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||
// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
|
||||
|
||||
protected:
|
||||
int32_t max_repeats_; // The number of repeats that the user requested
|
||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||
// The number of repeats that the user requested.
|
||||
// Note that num_repeats_ is different from op_total_repeats_ and op_num_repeats_per_epoch_ in the base DatasetOp class.
|
||||
// For example, for repeat1 op in pipeline tfreader -> repeat1(3) -> repeat2(2) -> epoch ctrl(4),
|
||||
// num_repeats_ = 3, op_total_repeats_ = 24, op_num_repeats_per_epoch_ = 6.
|
||||
int32_t num_repeats_;
|
||||
// A counter for the current number of executed repeats.
|
||||
// Note that repeat_count_ is different from op_current_repeats_ in the base DatasetOp class
|
||||
// because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats.
|
||||
int32_t repeat_count_;
|
||||
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -293,7 +293,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
|
|||
RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -310,6 +310,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ Status CifarOp::operator()() {
|
|||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -137,6 +137,7 @@ Status CifarOp::operator()() {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -271,13 +271,14 @@ Status ClueOp::operator()() {
|
|||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
|
|
@ -167,7 +167,7 @@ Status CocoOp::operator()() {
|
|||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block)));
|
||||
|
@ -184,6 +184,7 @@ Status CocoOp::operator()() {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -472,13 +472,14 @@ Status CsvOp::operator()() {
|
|||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
|
|
@ -216,7 +216,7 @@ Status GeneratorOp::operator()() {
|
|||
MS_LOG(DEBUG) << "Generator operator sends out EOE.";
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
// If last repeat or not repeated, push out EOF and exit master loop
|
||||
MS_LOG(DEBUG) << "Generator operator sends out EOF.";
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
|
@ -231,6 +231,7 @@ Status GeneratorOp::operator()() {
|
|||
// Clear the status of the wait post
|
||||
wp_.Clear();
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -151,7 +151,7 @@ Status ImageFolderOp::operator()() {
|
|||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone)));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block)));
|
||||
|
@ -168,6 +168,7 @@ Status ImageFolderOp::operator()() {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -112,7 +112,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
|
|||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -129,6 +129,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -378,7 +378,7 @@ Status MindRecordOp::operator()() {
|
|||
RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -396,6 +396,7 @@ Status MindRecordOp::operator()() {
|
|||
RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait());
|
||||
shard_reader_wait_post_.Clear();
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -111,7 +111,7 @@ Status MnistOp::operator()() {
|
|||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -128,6 +128,7 @@ Status MnistOp::operator()() {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -219,7 +219,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
|
|||
all_out_.Wait();
|
||||
// If we are not in a repeat loop, or that was the last repeat already, then setup our exit
|
||||
// condition from the master loop.
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
*quitting = true;
|
||||
}
|
||||
|
||||
|
@ -229,6 +229,7 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
|
|||
if (last_guy_in) {
|
||||
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker "
|
||||
<< eoe_worker_id_;
|
||||
UpdateRepeatAndEpochCounter();
|
||||
// Prepare for sync
|
||||
all_out_.Clear();
|
||||
// Always flow eoe at the end
|
||||
|
|
|
@ -419,13 +419,14 @@ Status TextFileOp::operator()() {
|
|||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
|
|
|
@ -308,13 +308,14 @@ Status TFReaderOp::operator()() {
|
|||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
|
|
|
@ -145,7 +145,7 @@ Status VOCOp::operator()() {
|
|||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
}
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
if (IsLastIteration()) {
|
||||
std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block)));
|
||||
|
@ -162,6 +162,7 @@ Status VOCOp::operator()() {
|
|||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -82,6 +82,7 @@ Status TakeOp::operator()() {
|
|||
|
||||
// Loop until non EOE is received
|
||||
if (buf->eoe()) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
take_count_ = 0;
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
|
||||
|
|
|
@ -25,18 +25,44 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {}
|
||||
// Constructor. Initializes the pass state: the repeated/merge/cached flags start out false, the
// nested repeat counter starts at 0, num_repeats_ and num_epochs_ start at 1, and cache_lookup_
// starts as nullptr.
RepeatPass::RepeatPass()
|
||||
: is_repeated_(false),
|
||||
nested_repeats_(0),
|
||||
num_repeats_(1),
|
||||
num_epochs_(1),
|
||||
is_merge_(false),
|
||||
is_cached_(false),
|
||||
cache_lookup_(nullptr) {}
|
||||
|
||||
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Create a new stack for eoe operators and push onto our stack of stacks.
|
||||
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
|
||||
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
|
||||
eoe_op_stacks_.push(std::move(new_stack));
|
||||
// If we are already repeated, then this is a nested repeat.
|
||||
if (is_repeated_) {
|
||||
nested_repeats_++;
|
||||
}
|
||||
is_repeated_ = true;
|
||||
|
||||
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
|
||||
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
|
||||
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
|
||||
num_repeats_ = -num_repeats_;
|
||||
}
|
||||
// This RepeatOp and its descendent nodes should be repeated for another num_repeats() times.
|
||||
//
|
||||
// Consider this example:
|
||||
// tfreader --> map --> repeat(2) --> epoch ctrl(3)
|
||||
// num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3),
|
||||
// meaning repeat op should be set to read 6 times (2*3), and so do map op and tfreader op.
|
||||
//
|
||||
// Another example:
|
||||
// tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4)
|
||||
// num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4),
|
||||
// meaning repeat2 and map op should be set to read 8 times (2*4).
|
||||
// Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times.
|
||||
num_repeats_ *= node->num_repeats();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -46,9 +72,16 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modifie
|
|||
// that RepeatOp does. However, epoch control is actually simpler because it can
|
||||
// only exist as the root node so it doesn't need all the nested code.
|
||||
// Create a new stack for eoe operators and push onto our stack of stacks.
|
||||
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
|
||||
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
|
||||
eoe_op_stacks_.push(std::move(new_stack));
|
||||
is_repeated_ = true;
|
||||
// Get the total number of epochs from the EpochCtrlOp parameter
|
||||
num_epochs_ = node->num_repeats();
|
||||
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
|
||||
// For example: tfreader --> epoch ctrl(3)
|
||||
// num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3),
|
||||
// meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op.
|
||||
num_repeats_ *= num_epochs_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -59,6 +92,13 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Identifies the subtree below this node as being cached
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
// Turn on the flag that we're under a cache op
|
||||
is_cached_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Hooks up any identified eoe nodes under this repeat.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
|
||||
|
@ -71,7 +111,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
|
||||
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
|
||||
// at this time, so we can pop it to get rid of it.
|
||||
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
if (!current_stack->empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
|
||||
}
|
||||
|
@ -82,14 +122,14 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
// from the save area, because the merge op above us may also take action on it later for a different
|
||||
// case when there is no repeat in the merge leg.
|
||||
if (is_merge_ && cache_lookup_) {
|
||||
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
cache_lookup_->set_total_repeats(num_repeats_);
|
||||
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
node->AddToEoeList(std::move(cache_lookup_));
|
||||
}
|
||||
|
||||
// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
|
||||
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
|
||||
if (nested_repeats_ > 0) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
AddToEOEOpStack(node);
|
||||
nested_repeats_--;
|
||||
} else {
|
||||
|
@ -99,7 +139,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
}
|
||||
is_repeated_ = false;
|
||||
}
|
||||
|
||||
if (is_cached_) {
|
||||
AddToCachedOpStack(node);
|
||||
}
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
// We finish the walk of this RepeatOp's descendent nodes.
|
||||
// The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
|
||||
// But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
|
||||
// so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
|
||||
num_repeats_ /= node->num_repeats();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -112,13 +161,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified)
|
|||
leaf_op = PopFromEOEOpStack();
|
||||
}
|
||||
is_repeated_ = false;
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
// We finish the walk of this EpochCtrl's descendent nodes.
|
||||
num_repeats_ /= node->num_repeats();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// CacheOp removes previous leaf ops and replaces them with itself
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
is_cached_ = false;
|
||||
if (is_repeated_) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
// if we are a cache within a repeat path of the tree, then there will be
|
||||
// eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the
|
||||
// repeat or epoch ctrl operators can work with them for repeat activity during runtime.
|
||||
|
@ -130,13 +183,23 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
|||
// the repeating behaviours shall be invoked against the cache op.
|
||||
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||
while (leaf_op != nullptr) {
|
||||
leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat);
|
||||
leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated);
|
||||
leaf_op = PopFromEOEOpStack();
|
||||
}
|
||||
AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node));
|
||||
|
||||
// adjust the total epochs and total repeats for ops under this cache op
|
||||
std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack();
|
||||
while (cached_op != nullptr) {
|
||||
int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_;
|
||||
cached_op->set_total_repeats(cached_op_total_repeats);
|
||||
// Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1
|
||||
cached_op->set_num_repeats_per_epoch(cached_op_total_repeats);
|
||||
cached_op = PopFromCachedOpStack();
|
||||
}
|
||||
}
|
||||
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -145,13 +208,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
|||
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||
// If we are in a repeat path, then set our repeated flag
|
||||
if (is_repeated_) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
|
||||
// if we are a leaf node then save ourself in a stack for the repeat operator above us
|
||||
if (node->IsLeaf()) {
|
||||
AddToEOEOpStack(node);
|
||||
}
|
||||
}
|
||||
if (is_cached_) {
|
||||
AddToCachedOpStack(node);
|
||||
}
|
||||
// Set total repeats and total epochs for the node
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -159,13 +226,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
|||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Setting the flag is needed since we didn't call the base class DatasetOp version
|
||||
if (is_repeated_) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
|
||||
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
|
||||
if (cache_lookup_) {
|
||||
cache_lookup_->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
AddToEOEOpStack(std::move(cache_lookup_));
|
||||
}
|
||||
}
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
|
||||
is_merge_ = false;
|
||||
return Status::OK();
|
||||
|
@ -178,13 +248,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
|
|||
RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
|
||||
}
|
||||
|
||||
// If we are in a repeat path already, then there must be a repeat above the merge op
|
||||
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
|
||||
if (is_repeated_) {
|
||||
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||
// Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that.
|
||||
}
|
||||
|
||||
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
|
||||
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
|
||||
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
|
||||
|
@ -197,19 +260,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
|
|||
|
||||
// Adds an operator to the eoe operator stack save area
|
||||
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
|
||||
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
current_stack->push(dataset_op);
|
||||
}
|
||||
|
||||
// Pops an operator from the eoe operator stack save area
|
||||
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
|
||||
std::shared_ptr<DatasetOp> top_op = nullptr;
|
||||
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
if (current_stack != nullptr && !current_stack->empty()) {
|
||||
top_op = current_stack->top();
|
||||
current_stack->pop();
|
||||
}
|
||||
return top_op;
|
||||
}
|
||||
|
||||
// Adds an operator to the cached operator stack save area
|
||||
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }
|
||||
|
||||
// Pops an operator from the cached operator stack save area
|
||||
std::shared_ptr<DatasetOp> RepeatPass::PopFromCachedOpStack() {
|
||||
std::shared_ptr<DatasetOp> top_op = nullptr;
|
||||
if (!cached_op_stacks_.empty()) {
|
||||
top_op = cached_op_stacks_.top();
|
||||
cached_op_stacks_.pop();
|
||||
}
|
||||
return top_op;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace dataset {
|
|||
/// to the eoe-producing (typically leaf) nodes underneath it.
|
||||
class RepeatPass : public NodePass {
|
||||
public:
|
||||
using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>;
|
||||
using op_stack = std::stack<std::shared_ptr<DatasetOp>>;
|
||||
|
||||
/// \brief Constructor
|
||||
RepeatPass();
|
||||
|
@ -56,6 +56,12 @@ class RepeatPass : public NodePass {
|
|||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
|
@ -103,11 +109,24 @@ class RepeatPass : public NodePass {
|
|||
/// \return shared_ptr to the popped operator
|
||||
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
|
||||
|
||||
bool is_repeated_; // T/F if we are processing under a repeat
|
||||
bool is_merge_; // T/F if we are processing under a cache merge op
|
||||
int32_t nested_repeats_; // A counter for nested repeats
|
||||
std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)
|
||||
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
|
||||
/// \brief Adds an operator to the cached operator stack save area
|
||||
/// \param dataset_op - The dataset op to add to the cached op stack
|
||||
/// \note This function does not return a value (it is declared void).
|
||||
void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
/// \brief Pops an operator from the cached operator stack save area
|
||||
/// \return shared_ptr to the popped operator
|
||||
std::shared_ptr<DatasetOp> PopFromCachedOpStack();
|
||||
|
||||
bool is_repeated_; // T/F if we are processing under a repeat
|
||||
bool is_merge_; // T/F if we are processing under a cache merge op
|
||||
bool is_cached_; // T/F is we are processing under a cache op
|
||||
int32_t nested_repeats_; // A counter for nested repeats
|
||||
int32_t num_repeats_; // A multiplier to the total number of repeats
|
||||
int32_t num_epochs_; // To save the total number of epochs
|
||||
std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)
|
||||
op_stack cached_op_stacks_; // A save area for ops under a cache op
|
||||
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -565,6 +565,99 @@ def test_generator_tuple_repeat_repeat_3():
|
|||
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_tuple_infinite_repeat_repeat_1():
|
||||
"""
|
||||
test generator tuple infinite repeat repeat 1
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat()
|
||||
data1 = data1.repeat(3)
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=11)
|
||||
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
np.testing.assert_array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
if i == 100:
|
||||
break
|
||||
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_tuple_infinite_repeat_repeat_2():
|
||||
"""
|
||||
test generator tuple infinite repeat repeat 2
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat(3)
|
||||
data1 = data1.repeat()
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=11)
|
||||
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
np.testing.assert_array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
if i == 100:
|
||||
break
|
||||
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_tuple_infinite_repeat_repeat_3():
|
||||
"""
|
||||
test generator tuple infinite repeat repeat 3
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat()
|
||||
data1 = data1.repeat()
|
||||
iter1 = data1.create_tuple_iterator(num_epochs=11)
|
||||
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
np.testing.assert_array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
if i == 100:
|
||||
break
|
||||
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_tuple_infinite_repeat_repeat_4():
|
||||
"""
|
||||
test generator tuple infinite repeat repeat 4
|
||||
"""
|
||||
logger.info("Test 1D Generator : 0 - 63")
|
||||
|
||||
# apply dataset operations
|
||||
data1 = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data1 = data1.repeat()
|
||||
data1 = data1.repeat()
|
||||
iter1 = data1.create_tuple_iterator()
|
||||
|
||||
i = 0
|
||||
for item in iter1: # each data is a dictionary
|
||||
golden = np.array([i % 64])
|
||||
np.testing.assert_array_equal(item[0], golden)
|
||||
i = i + 1
|
||||
if i == 100:
|
||||
break
|
||||
|
||||
# rely on garbage collector to destroy iter1
|
||||
|
||||
|
||||
def test_generator_reusedataset():
|
||||
"""
|
||||
test generator reusedataset
|
||||
|
|
Loading…
Reference in New Issue