!483 Optimize skip dataset op

Merge pull request !483 from jiangzhiwen/dataset/skip_opt
This commit is contained in:
mindspore-ci-bot 2020-04-21 15:10:23 +08:00 committed by Gitee
commit d9e4dcc33b
2 changed files with 32 additions and 24 deletions

View File

@ -67,9 +67,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
}
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Drop first max_skips_ rows
while (skip_count_ < max_skips_) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
if (buf->eoe() || buf->eof()) {
break;
}
@ -77,31 +78,24 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
// Consider the rows of buffer more than 1
TensorRow drop_row;
int row_num = buf->NumRows();
for (int i = 0; i < row_num; i++) {
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
skip_count_ += drop_num;
for (int i = 0; i < drop_num; i++) {
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
if (++skip_count_ == max_skips_) {
break;
}
}
if (buf->NumRows() == 0) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
}
// If buffer is none or the rows of buffer is 0,
// then get a buffer from child.
if (!buf || buf->NumRows() == 0) {
if (buf && buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
// Handling eoe and eof
if (buf->eoe() || buf->eof()) {
// Handling eoe
if (buf->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
if (state_ == OpState::kDeOpIdle) {
*p_buffer = std::move(buf);
return Status::OK();
}
}
// Handling eof
if (buf->eof()) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
}
*p_buffer = std::move(buf);
@ -125,7 +119,7 @@ Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is a
// Base-class override for handling cases when an eof is received.
// SkipOp performs no work of its own on EOF: it records the event in the log
// and reports success.
// @param worker_id - id of the calling worker (not used by this override).
// @return Status::OK() in all cases.
Status SkipOp::EofReceived(int32_t worker_id) {
  // Emit the message once, at DEBUG level only, so a routine end-of-data event
  // is not also reported at INFO level.
  MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
  return Status::OK();
}
} // namespace dataset

View File

@ -22,7 +22,11 @@ from mindspore import log as logger
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_tf_skip():
"""
a simple skip operation.
"""
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
resize_height, resize_width = 32, 32
@ -37,11 +41,15 @@ def test_tf_skip():
num_iter += 1
assert num_iter == 1
def generator_md():
    """
    Yield five single-column rows holding the values 0 through 4, in order.

    Each row is a one-element tuple whose only item is a NumPy array of
    shape (1,) containing the current value.
    """
    for i in range(5):
        # The trailing comma keeps each row a tuple (one column per row).
        yield (np.array([i]), )
def test_generator_skip():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
@ -53,6 +61,7 @@ def test_generator_skip():
buf.append(data[0][0])
assert len(buf) == 2
def test_skip_1():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
@ -64,6 +73,7 @@ def test_skip_1():
buf.append(data[0][0])
assert len(buf) == 0
def test_skip_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
@ -75,6 +85,7 @@ def test_skip_2():
buf.append(data[0][0])
assert len(buf) == 5
def test_skip_repeat_1():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
@ -89,6 +100,7 @@ def test_skip_repeat_1():
buf.append(data[0][0])
assert len(buf) == 7
def test_skip_repeat_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
@ -103,6 +115,7 @@ def test_skip_repeat_2():
buf.append(data[0][0])
assert len(buf) == 4
def test_skip_repeat_3():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
@ -120,6 +133,7 @@ def test_skip_repeat_3():
buf.append(data[0][0])
assert len(buf) == 6
if __name__ == "__main__":
test_tf_skip()
test_generator_skip()
@ -127,4 +141,4 @@ if __name__ == "__main__":
test_skip_2()
test_skip_repeat_1()
test_skip_repeat_2()
test_skip_repeat_3()
test_skip_repeat_3()