forked from mindspore-Ecosystem/mindspore
!483 Optimize skip dataset op
Merge pull request !483 from jiangzhiwen/dataset/skip_opt
This commit is contained in:
commit
d9e4dcc33b
|
@ -67,9 +67,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
|
|||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> buf;
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
|
||||
// Drop first max_skips_ rows
|
||||
while (skip_count_ < max_skips_) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
if (buf->eoe() || buf->eof()) {
|
||||
break;
|
||||
}
|
||||
|
@ -77,31 +78,24 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
|
|||
// Consider the rows of buffer more than 1
|
||||
TensorRow drop_row;
|
||||
int row_num = buf->NumRows();
|
||||
for (int i = 0; i < row_num; i++) {
|
||||
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
||||
skip_count_ += drop_num;
|
||||
for (int i = 0; i < drop_num; i++) {
|
||||
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
|
||||
if (++skip_count_ == max_skips_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (buf->NumRows() == 0) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
}
|
||||
}
|
||||
|
||||
// If buffer is none or the rows of buffer is 0,
|
||||
// then get a buffer from child.
|
||||
if (!buf || buf->NumRows() == 0) {
|
||||
if (buf && buf->eof()) {
|
||||
*p_buffer = std::move(buf);
|
||||
return Status::OK();
|
||||
}
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
}
|
||||
|
||||
// Handling eoe and eof
|
||||
if (buf->eoe() || buf->eof()) {
|
||||
// Handling eoe
|
||||
if (buf->eoe()) {
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
if (state_ == OpState::kDeOpIdle) {
|
||||
*p_buffer = std::move(buf);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Handling eof
|
||||
if (buf->eof()) {
|
||||
RETURN_IF_NOT_OK(EofReceived(worker_id));
|
||||
}
|
||||
|
||||
*p_buffer = std::move(buf);
|
||||
|
@ -125,7 +119,7 @@ Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is a
|
|||
|
||||
// Base-class override for handling cases when an eof is received.
|
||||
// Handles an end-of-file notification for the skip operator.
// The body only emits a log message and reports success; `worker_id` is not
// read anywhere in the body.
// Returns Status::OK() unconditionally.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped (the preceding hunk header reads "@ -125,7 +119,7 @@", i.e. one
// line replaced, and bare "|" rows separate every statement). The INFO and
// DEBUG statements below carry the same message and are presumably the removed
// and added versions of a single line — confirm against the repository before
// treating both as live code.
Status SkipOp::EofReceived(int32_t worker_id) {
|
||||
MS_LOG(INFO) << "Skip operator EOF received, do nothing now.";
|
||||
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
|
||||
// No state is changed on EOF; success is reported as-is.
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -22,7 +22,11 @@ from mindspore import log as logger
|
|||
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def test_tf_skip():
|
||||
"""
|
||||
a simple skip operation.
|
||||
"""
|
||||
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
|
||||
|
||||
resize_height, resize_width = 32, 32
|
||||
|
@ -37,11 +41,15 @@ def test_tf_skip():
|
|||
num_iter += 1
|
||||
assert num_iter == 1
|
||||
|
||||
|
||||
def generator_md():
|
||||
# Create a dataset with [0, 1, 2, 3, 4]
|
||||
"""
|
||||
create a dataset with [0, 1, 2, 3, 4]
|
||||
"""
|
||||
for i in range(5):
|
||||
yield (np.array([i]), )
|
||||
|
||||
|
||||
def test_generator_skip():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
|
@ -53,6 +61,7 @@ def test_generator_skip():
|
|||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
|
||||
|
||||
def test_skip_1():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
|
@ -64,6 +73,7 @@ def test_skip_1():
|
|||
buf.append(data[0][0])
|
||||
assert len(buf) == 0
|
||||
|
||||
|
||||
def test_skip_2():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
|
@ -75,6 +85,7 @@ def test_skip_2():
|
|||
buf.append(data[0][0])
|
||||
assert len(buf) == 5
|
||||
|
||||
|
||||
def test_skip_repeat_1():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
|
@ -89,6 +100,7 @@ def test_skip_repeat_1():
|
|||
buf.append(data[0][0])
|
||||
assert len(buf) == 7
|
||||
|
||||
|
||||
def test_skip_repeat_2():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
|
@ -103,6 +115,7 @@ def test_skip_repeat_2():
|
|||
buf.append(data[0][0])
|
||||
assert len(buf) == 4
|
||||
|
||||
|
||||
def test_skip_repeat_3():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
|
@ -120,6 +133,7 @@ def test_skip_repeat_3():
|
|||
buf.append(data[0][0])
|
||||
assert len(buf) == 6
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tf_skip()
|
||||
test_generator_skip()
|
||||
|
@ -127,4 +141,4 @@ if __name__ == "__main__":
|
|||
test_skip_2()
|
||||
test_skip_repeat_1()
|
||||
test_skip_repeat_2()
|
||||
test_skip_repeat_3()
|
||||
test_skip_repeat_3()
|
||||
|
|
Loading…
Reference in New Issue