!243 Support nested repeat
Merge pull request !243 from h.farahat/nested_repeat
This commit is contained in:
commit
30de261c3c
|
@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
|
|||
return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
|
||||
}
|
||||
|
||||
// During tree prepare phase, operators may have specific operations to perform depending on
|
||||
// During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||
// their role.
|
||||
Status DatasetOp::PrepareNodeAction() {
|
||||
Status DatasetOp::PrepareNodePreAction() {
|
||||
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
|
||||
return Status::OK();
|
||||
}
|
||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||
// their role.
|
||||
Status DatasetOp::PrepareNodePostAction() {
|
||||
// If this op does not have any children and it is in a repeat path of the tree...
|
||||
if (child_.size() == 0 && BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) {
|
||||
// Then, flag this operator as a leaf node in a repeat path of tree execution.
|
||||
BitSet(&op_ctrl_flags_, kDeOpRepeated);
|
||||
|
||||
// Secondly, push ourselves onto the tree repeat stack. Later, the repeat operator
|
||||
if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
|
||||
// push ourselves onto the tree repeat stack. Later, the repeat operator
|
||||
// above us will consume them.
|
||||
tree_->AddToRepeatStack(shared_from_this());
|
||||
}
|
||||
|
|
|
@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// During tree prepare phase, operators may have specific operations to perform depending on
|
||||
// During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
// before providing their own implementations.
|
||||
virtual Status PrepareNodeAction();
|
||||
virtual Status PrepareNodePreAction();
|
||||
|
||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
// before providing their own implementations.
|
||||
virtual Status PrepareNodePostAction();
|
||||
|
||||
// Getter function
|
||||
// @return The operator id
|
||||
|
|
|
@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp {
|
|||
return out;
|
||||
}
|
||||
|
||||
// During tree prepare phase, operators may have specific operations to perform depending on
|
||||
// During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
// before providing their own implementations.
|
||||
// @return Status - The error return code
|
||||
Status PrepareNodeAction() override {
|
||||
Status PrepareNodePreAction() override {
|
||||
// Run common code from super class before adding ParallelOp specific logic
|
||||
return (DatasetOp::PrepareNodeAction());
|
||||
return (DatasetOp::PrepareNodePreAction());
|
||||
}
|
||||
|
||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
// before providing their own implementations.
|
||||
// @return Status - The error return code
|
||||
Status PrepareNodePostAction() override {
|
||||
// Run common code from super class before adding ParallelOp specific logic
|
||||
return (DatasetOp::PrepareNodePostAction());
|
||||
}
|
||||
|
||||
// Override base class reset to provide reset actions specific to the ParallelOp class.
|
||||
|
|
|
@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp {
|
|||
// @return The number of threads that push data to the output connector
|
||||
int32_t num_producers() const override { return 1; }
|
||||
|
||||
// During tree prepare phase, operators may have specific operations to perform depending on
|
||||
// During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
// before providing their own implementations.
|
||||
Status PrepareNodeAction() override {
|
||||
Status PrepareNodePreAction() override {
|
||||
// Run common code from super class before adding PipelineOp specific logic
|
||||
return (DatasetOp::PrepareNodeAction());
|
||||
return (DatasetOp::PrepareNodePreAction());
|
||||
}
|
||||
|
||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
// before providing their own implementations.
|
||||
Status PrepareNodePostAction() override {
|
||||
// Run common code from super class before adding PipelineOp specific logic
|
||||
return (DatasetOp::PrepareNodePostAction());
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
|
|||
out << "RepeatOp:"
|
||||
<< "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
|
||||
<< "\nLeaf Nodes in my execution path:";
|
||||
if (!leaf_ops_.empty()) {
|
||||
if (!eoe_ops_.empty()) {
|
||||
out << "\n";
|
||||
for (size_t i = 0; i < leaf_ops_.size(); i++) {
|
||||
out << " Operator: " << leaf_ops_[i]->id() << "\n";
|
||||
for (size_t i = 0; i < eoe_ops_.size(); i++) {
|
||||
out << " Operator: " << eoe_ops_[i]->id() << "\n";
|
||||
}
|
||||
} else {
|
||||
out << " kNone.";
|
||||
|
@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Base-class override for executing specific RepeatOp configurations. This code will be called
|
||||
// during the execution tree prepare phase when it is visiting this operator.
|
||||
Status RepeatOp::PrepareNodeAction() {
|
||||
Status RepeatOp::PrepareNodePostAction() {
|
||||
// Run any common code from super class first before adding our own specific logic
|
||||
RETURN_IF_NOT_OK(PipelineOp::PrepareNodeAction());
|
||||
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
|
||||
std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
|
||||
while (leaf_op != nullptr) {
|
||||
// Track the leaf operators that are under this repeat op.
|
||||
leaf_ops_.push_back(leaf_op);
|
||||
|
||||
// Special case. If the repeat count is 1, then pre-flag the leaf nodes
|
||||
// to tell them they are already at their last op:
|
||||
if (max_repeats_ == 1) {
|
||||
leaf_op->set_control_flag(kDeOpLastRepeat);
|
||||
}
|
||||
eoe_ops_.push_back(leaf_op);
|
||||
leaf_op = tree_->PopFromRepeatStack();
|
||||
}
|
||||
// Push ourselves to the stack in case one of our ascendants is repeat too.
|
||||
tree_->AddToRepeatStack(shared_from_this());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
|
|||
Status RepeatOp::EoeReceived(int32_t worker_id) {
|
||||
repeat_count_++;
|
||||
MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
|
||||
|
||||
// If we've reached the requested repeat count, then flag the leaf nodes
|
||||
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
|
||||
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
|
||||
// If we've reached the requested repeat count, then flag the eoe nodes
|
||||
// to tell them they've got one more epoch to perform. When they reach the end
|
||||
// of the last epoch, they quit rather than loop again.
|
||||
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
|
||||
for (size_t i = 0; i < leaf_ops_.size(); i++) {
|
||||
leaf_ops_[i]->set_control_flag(kDeOpLastRepeat);
|
||||
// of the last epoch, they quit rather than loop again. This happens in two cases:
|
||||
// 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
|
||||
// 2- We are not repeated
|
||||
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
|
||||
for (auto &eoe_op : eoe_ops_) {
|
||||
eoe_op->set_control_flag(kDeOpLastRepeat);
|
||||
}
|
||||
}
|
||||
if (repeat_count_ == max_repeats_) {
|
||||
repeat_count_ = 0;
|
||||
state_ = OpState::kDeOpIdle;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp {
|
|||
uint32_t PrepareFlags() const override;
|
||||
|
||||
// Base-class override for executing specific RepeatOp configurations. This code will be called
|
||||
// during the execution tree prepare phase when it is visiting this operator.
|
||||
Status PrepareNodeAction() override;
|
||||
// during the execution tree post-prepare phase when it is visiting this operator.
|
||||
Status PrepareNodePostAction() override;
|
||||
|
||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
||||
|
@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp {
|
|||
int32_t num_producers() const override;
|
||||
|
||||
private:
|
||||
int32_t max_repeats_; // The number of repeats that the user requested
|
||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||
std::vector<std::shared_ptr<DatasetOp>> leaf_ops_; // List of leaf operators underneath this repeat.
|
||||
int32_t max_repeats_; // The number of repeats that the user requested
|
||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() {
|
|||
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
||||
// node actions during a tree walk.
|
||||
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
|
||||
int32_t num_children = dataset_op->child_.size();
|
||||
// execute PreAction
|
||||
RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());
|
||||
|
||||
// Before going down into children, make any prepare flags updates based on this
|
||||
// operator.
|
||||
// Before going down into children, make any prepare flags updates based on this operator.
|
||||
uint32_t op_prep_flags = dataset_op->PrepareFlags();
|
||||
// Sanity check. In future we can support nested repeats. for now it's not allowed.
|
||||
// If somebody above us already set the repeat flag, and now we are another repeat...
|
||||
if (BitTest(op_prep_flags, kDePrepRepeat) && BitTest(prepare_flags_, kDePrepRepeat)) {
|
||||
std::string err_msg("Nested RepeatOp detected! This is not supported yet.");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
BitSet(&prepare_flags_, op_prep_flags);
|
||||
|
||||
// Now, descend to children
|
||||
for (int32_t i = 0; i < num_children; ++i) {
|
||||
RETURN_IF_NOT_OK(this->PrepareNode(dataset_op->child_[i]));
|
||||
for (const auto &i : dataset_op->child_) {
|
||||
RETURN_IF_NOT_OK(this->PrepareNode(i));
|
||||
}
|
||||
|
||||
// No more children, now we execute any prepare actions before going back up the
|
||||
// the tree on recursive function exit
|
||||
RETURN_IF_NOT_OK(dataset_op->PrepareNodeAction());
|
||||
|
||||
// Then clear the flags from this op now that we have prepared it.
|
||||
BitClear(&prepare_flags_, op_prep_flags);
|
||||
|
||||
// No more children, now we execute any prepare actions before going back up the
|
||||
// the tree on recursive function
|
||||
RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -419,6 +419,8 @@ class Dataset:
|
|||
>>> repeat_and_shuffle = data.repeat(50)
|
||||
>>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
|
||||
"""
|
||||
if count == 1:
|
||||
return self
|
||||
return RepeatDataset(self, count)
|
||||
|
||||
@check_zip_dataset
|
||||
|
|
|
@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
|
|||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32);
|
||||
|
||||
std::shared_ptr<DatasetOp> leaf_op = std::make_shared<RepeatOp>(16);
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
|
||||
// TFReaderOp
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder builder;
|
||||
builder.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(16)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(16);
|
||||
Status rc= builder.Build(&my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
my_tree->AssociateNode(parent_op);
|
||||
my_tree->AssociateNode(leaf_op);
|
||||
ASSERT_NE(parent_op, nullptr);
|
||||
ASSERT_NE(leaf_op, nullptr);
|
||||
parent_op->AddChild(std::move(leaf_op));
|
||||
parent_op->Print(std::cout, false);
|
||||
parent_op->PrepareNodeAction();
|
||||
ASSERT_NE(my_tfreader_op, nullptr);
|
||||
parent_op->AddChild(std::move(my_tfreader_op));
|
||||
MS_LOG(INFO) << parent_op;
|
||||
my_tree->Prepare();
|
||||
|
||||
RepeatOp RepeatOpOp();
|
||||
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
Status rc = RepeatOp::Builder(3).Build(&repeat_op);
|
||||
rc = RepeatOp::Builder(3).Build(&repeat_op);
|
||||
ASSERT_NE(repeat_op, nullptr);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
|
|||
from util import save_and_check
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
|
||||
|
@ -95,6 +96,141 @@ def test_tf_repeat_03():
|
|||
assert num_iter == 2
|
||||
|
||||
|
||||
def generator():
|
||||
for i in range(3):
|
||||
yield np.array([i]),
|
||||
|
||||
|
||||
def test_nested_repeat1():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(2)
|
||||
data = data.repeat(3)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
assert sum([1 for _ in data]) == 2 * 3 * 3
|
||||
|
||||
|
||||
def test_nested_repeat2():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(1)
|
||||
data = data.repeat(1)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
assert sum([1 for _ in data]) == 3
|
||||
|
||||
|
||||
def test_nested_repeat3():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(1)
|
||||
data = data.repeat(2)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
assert sum([1 for _ in data]) == 2 * 3
|
||||
|
||||
|
||||
def test_nested_repeat4():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(2)
|
||||
data = data.repeat(1)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
assert sum([1 for _ in data]) == 2 * 3
|
||||
|
||||
|
||||
def test_nested_repeat5():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.batch(3)
|
||||
data = data.repeat(2)
|
||||
data = data.repeat(3)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
|
||||
|
||||
assert sum([1 for _ in data]) == 6
|
||||
|
||||
|
||||
def test_nested_repeat6():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(2)
|
||||
data = data.batch(3)
|
||||
data = data.repeat(3)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
|
||||
|
||||
assert sum([1 for _ in data]) == 6
|
||||
|
||||
|
||||
def test_nested_repeat7():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(2)
|
||||
data = data.repeat(3)
|
||||
data = data.batch(3)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
|
||||
|
||||
assert sum([1 for _ in data]) == 6
|
||||
|
||||
|
||||
def test_nested_repeat8():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.batch(2, drop_remainder=False)
|
||||
data = data.repeat(2)
|
||||
data = data.repeat(3)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
if i % 2 == 0:
|
||||
assert np.array_equal(d[0], np.asarray([[0], [1]]))
|
||||
else:
|
||||
assert np.array_equal(d[0], np.asarray([[2]]))
|
||||
|
||||
assert sum([1 for _ in data]) == 6 * 2
|
||||
|
||||
|
||||
def test_nested_repeat9():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat()
|
||||
data = data.repeat(3)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
if i == 10:
|
||||
break
|
||||
|
||||
|
||||
def test_nested_repeat10():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(3)
|
||||
data = data.repeat()
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
if i == 10:
|
||||
break
|
||||
|
||||
|
||||
def test_nested_repeat11():
|
||||
data = ds.GeneratorDataset(generator, ["data"])
|
||||
data = data.repeat(2)
|
||||
data = data.repeat(3)
|
||||
data = data.repeat(4)
|
||||
data = data.repeat(5)
|
||||
|
||||
for i, d in enumerate(data):
|
||||
assert i % 3 == d[0][0]
|
||||
|
||||
assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("--------test tf repeat 01---------")
|
||||
# test_repeat_01()
|
||||
|
@ -104,4 +240,3 @@ if __name__ == "__main__":
|
|||
|
||||
logger.info("--------test tf repeat 03---------")
|
||||
test_tf_repeat_03()
|
||||
|
||||
|
|
Loading…
Reference in New Issue