forked from mindspore-Ecosystem/mindspore
fix skip
This commit is contained in:
parent
9399dffe0e
commit
34bfa2f7c9
|
@ -16,6 +16,7 @@
|
|||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/datasetops/skip_op.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
|
@ -26,7 +27,10 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
|
||||
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
Status SkipOp::Builder::SanityCheck() const {
|
||||
if (build_max_skips_ < 0) {
|
||||
|
@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
|
|||
// The builder "build" method creates the final object.
|
||||
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<SkipOp>(build_max_skips_);
|
||||
*ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of the SkipOp.
|
||||
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
|
||||
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
|
||||
: PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
|
||||
|
||||
// Destructor
|
||||
SkipOp::~SkipOp() {}
|
||||
|
@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
|
|||
<< "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
|
||||
}
|
||||
|
||||
// Since the buffer may contain multi rows, this function will drop the rows
|
||||
// that need to skip in it, and then return the buffer.
|
||||
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
|
||||
if (child_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> buf;
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
|
||||
// Drop first max_skips_ rows
|
||||
while (skip_count_ < max_skips_) {
|
||||
if (buf->eoe() || buf->eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Consider the rows of buffer more than 1
|
||||
TensorRow drop_row;
|
||||
int row_num = buf->NumRows();
|
||||
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
||||
skip_count_ += drop_num;
|
||||
for (int i = 0; i < drop_num; i++) {
|
||||
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
|
||||
}
|
||||
if (buf->NumRows() == 0) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
}
|
||||
}
|
||||
|
||||
// Handling eoe
|
||||
if (buf->eoe()) {
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
}
|
||||
|
||||
// Handling eof
|
||||
if (buf->eof()) {
|
||||
RETURN_IF_NOT_OK(EofReceived(worker_id));
|
||||
}
|
||||
|
||||
*p_buffer = std::move(buf);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
Status SkipOp::EoeReceived(int32_t worker_id) {
|
||||
skip_count_ = 0;
|
||||
|
@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Class functor operator () override.
|
||||
// Most dataset ops operate by launching a thread (see ExecutionTree).
|
||||
// However, the SkipOp is defined as a inlined operator, so it is invalid to
|
||||
// launch the functor since this op runs inlined inside another operator. The
|
||||
// function is overloaded to ensure that it is not called by mistake (it will
|
||||
// generate an error).
|
||||
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
|
||||
// main entry point for skip
|
||||
Status SkipOp::operator()() {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<DataBuffer> curr_buffer;
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
while (curr_buffer->eof() == false) {
|
||||
// Reset count
|
||||
skip_count_ = 0;
|
||||
while (curr_buffer->eoe() == false) {
|
||||
// Drop first count rows
|
||||
while (skip_count_ < max_skips_) {
|
||||
if (curr_buffer->eoe() || curr_buffer->eof()) {
|
||||
break;
|
||||
}
|
||||
// Consider the rows of buffer more than one
|
||||
TensorRow drop_row;
|
||||
int row_num = curr_buffer->NumRows();
|
||||
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
||||
skip_count_ += drop_num;
|
||||
for (int i = 0; i < drop_num; i++) {
|
||||
RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
|
||||
}
|
||||
if (curr_buffer->NumRows() == 0) {
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
}
|
||||
// we got eoe, now try again until we got eof
|
||||
MS_LOG(DEBUG) << "Skip operator EOE Received.";
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Skip operator EOF Received.";
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Base-class override for handling cases when an eof is received.
|
||||
Status SkipOp::EofReceived(int32_t worker_id) {
|
||||
|
|
|
@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
|
|||
|
||||
private:
|
||||
int32_t build_max_skips_;
|
||||
int32_t builder_op_connector_size_;
|
||||
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
|
|||
// Constructor of the SkipOp.
|
||||
// @note The builder class should be used to call it
|
||||
// @param count - The number of skips to do
|
||||
explicit SkipOp(int32_t count);
|
||||
explicit SkipOp(int32_t count, int32_t op_connector_size);
|
||||
|
||||
// Destructor
|
||||
~SkipOp();
|
||||
|
@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
|
|||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Class functor operator () override.
|
||||
// Most dataset ops operate by launching a thread (see ExecutionTree).
|
||||
// However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
|
||||
// functor since this op runs inlined inside another operator. The function is overloaded to
|
||||
// ensure that it is not called by mistake (it will generate an error).
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
||||
// Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
|
||||
// a buffer from our child.
|
||||
// @param p_buffer - output pointer to the buffer that it will fetch.
|
||||
// @param worker_id - The worker id
|
||||
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
|
||||
// @return Status - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
// @param worker_id - The worker id
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
|
|
@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// SkipOp
|
||||
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
|
||||
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
|
||||
rc = my_tree->AssociateNode(skip_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
@ -51,7 +50,7 @@ def generator_md():
|
|||
|
||||
|
||||
def test_generator_skip():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)
|
||||
|
||||
# Here ds1 should be [3, 4]
|
||||
ds1 = ds1.skip(3)
|
||||
|
@ -60,6 +59,7 @@ def test_generator_skip():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
assert buf == [3, 4]
|
||||
|
||||
|
||||
def test_skip_1():
|
||||
|
@ -72,6 +72,7 @@ def test_skip_1():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 0
|
||||
assert buf == []
|
||||
|
||||
|
||||
def test_skip_2():
|
||||
|
@ -84,6 +85,7 @@ def test_skip_2():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 5
|
||||
assert buf == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_skip_repeat_1():
|
||||
|
@ -99,6 +101,7 @@ def test_skip_repeat_1():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 7
|
||||
assert buf == [3, 4, 0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_skip_repeat_2():
|
||||
|
@ -114,6 +117,7 @@ def test_skip_repeat_2():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 4
|
||||
assert buf == [3, 4, 3, 4]
|
||||
|
||||
|
||||
def test_skip_repeat_3():
|
||||
|
@ -132,6 +136,62 @@ def test_skip_repeat_3():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 6
|
||||
assert buf == [3, 4, 3, 4, 3, 4]
|
||||
|
||||
def test_skip_take_1():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
# Here ds1 should be [0, 1, 2, 3]
|
||||
ds1 = ds1.take(4)
|
||||
|
||||
# Here ds1 should be [2, 3]
|
||||
ds1 = ds1.skip(2)
|
||||
|
||||
buf = []
|
||||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
assert buf == [2, 3]
|
||||
|
||||
def test_skip_take_2():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
# Here ds1 should be [2, 3, 4]
|
||||
ds1 = ds1.skip(2)
|
||||
|
||||
# Here ds1 should be [2, 3]
|
||||
ds1 = ds1.take(2)
|
||||
|
||||
buf = []
|
||||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
assert buf == [2, 3]
|
||||
|
||||
|
||||
def generator_1d():
|
||||
for i in range(64):
|
||||
yield (np.array([i]), )
|
||||
|
||||
def test_skip_filter_1():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ['data'])
|
||||
dataset = dataset.skip(5)
|
||||
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
|
||||
|
||||
buf = []
|
||||
for item in dataset:
|
||||
buf.append(item[0][0])
|
||||
assert buf == [5, 6, 7, 8, 9, 10]
|
||||
|
||||
def test_skip_filter_2():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ['data'])
|
||||
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
|
||||
dataset = dataset.skip(5)
|
||||
|
||||
buf = []
|
||||
for item in dataset:
|
||||
buf.append(item[0][0])
|
||||
assert buf == [5, 6, 7, 8, 9, 10]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -142,3 +202,7 @@ if __name__ == "__main__":
|
|||
test_skip_repeat_1()
|
||||
test_skip_repeat_2()
|
||||
test_skip_repeat_3()
|
||||
test_skip_take_1()
|
||||
test_skip_take_2()
|
||||
test_skip_filter_1()
|
||||
test_skip_filter_2()
|
||||
|
|
Loading…
Reference in New Issue