diff --git a/mindspore/ccsrc/minddata/dataset/util/queue.h b/mindspore/ccsrc/minddata/dataset/util/queue.h index c7cd414bfab..e3d5a6d45b7 100644 --- a/mindspore/ccsrc/minddata/dataset/util/queue.h +++ b/mindspore/ccsrc/minddata/dataset/util/queue.h @@ -69,6 +69,7 @@ class Queue { void Reset() { std::unique_lock _lock(mux_); ResetQue(); + extra_arr_.clear(); } // Producer @@ -91,8 +92,7 @@ class Queue { // Block when full Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); if (rc.IsOk()) { - auto k = tail_++ % sz_; - *(arr_[k]) = std::forward(ele); + RETURN_IF_NOT_OK(AddWhileHoldingLock(std::forward(ele))); empty_cv_.NotifyAll(); _lock.unlock(); } else { @@ -123,7 +123,7 @@ class Queue { // Block when empty Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !empty(); }); if (rc.IsOk()) { - RETURN_IF_NOT_OK(PopFrontWhileHoldingLock(p)); + RETURN_IF_NOT_OK(PopFrontWhileHoldingLock(p, true)); full_cv_.NotifyAll(); _lock.unlock(); } else { @@ -144,19 +144,39 @@ class Queue { Status Resize(int32_t new_capacity) { std::unique_lock _lock(mux_); - CHECK_FAIL_RETURN_UNEXPECTED( - new_capacity >= static_cast(size()), - "New capacity: " + std::to_string(new_capacity) + ", is smaller than queue size:" + std::to_string(size())); + CHECK_FAIL_RETURN_UNEXPECTED(new_capacity > 0, + "New capacity: " + std::to_string(new_capacity) + ", should be larger than 0"); + RETURN_OK_IF_TRUE(new_capacity == capacity()); std::vector queue; + // pop from the original queue until the new_capacity is full + for (int32_t i = 0; i < new_capacity; ++i) { + if (head_ < tail_) { + // if there are elements left in queue, pop out + T temp; + RETURN_IF_NOT_OK(this->PopFrontWhileHoldingLock(&temp, true)); + queue.push_back(temp); + } else { + // if there is nothing left in queue, check extra_arr_ + if (!extra_arr_.empty()) { + // if extra_arr_ is not empty, push to fill the new_capacity + queue.push_back(extra_arr_[0]); + extra_arr_.erase(extra_arr_.begin()); + } else { + // if everything in the queue and extra_arr_ is popped out, break the loop + break; + } + } + } + // if there are extra elements in queue, put them to extra_arr_ while (head_ < tail_) { T temp; - RETURN_IF_NOT_OK(this->PopFrontWhileHoldingLock(&temp)); - queue.push_back(temp); + RETURN_IF_NOT_OK(this->PopFrontWhileHoldingLock(&temp, false)); + extra_arr_.push_back(temp); } this->ResetQue(); RETURN_IF_NOT_OK(arr_.allocate(new_capacity)); sz_ = new_capacity; - for (int i = 0; i < queue.size(); ++i) { + for (int32_t i = 0; i < queue.size(); ++i) { RETURN_IF_NOT_OK(this->AddWhileHoldingLock(queue[i])); } queue.clear(); @@ -167,6 +187,8 @@ class Queue { private: size_t sz_; MemGuard> arr_; + std::vector extra_arr_; // used to store extra elements after reducing capacity, will not be changed by Add, + // will pop when there is a space in queue (by PopFront or Resize) size_t head_; size_t tail_; std::string my_name_; @@ -181,17 +203,28 @@ class Queue { return Status::OK(); } + // Helper function for Add, must be called when holding a lock + Status AddWhileHoldingLock(T &&ele) { + auto k = tail_++ % sz_; + *(arr_[k]) = std::forward(ele); + return Status::OK(); + } + // Helper function for PopFront, must be called when holding a lock - Status PopFrontWhileHoldingLock(pointer p) { + Status PopFrontWhileHoldingLock(pointer p, bool clean_extra) { auto k = head_++ % sz_; *p = std::move(*(arr_[k])); + if (!extra_arr_.empty() && clean_extra) { + RETURN_IF_NOT_OK(this->AddWhileHoldingLock(std::forward(extra_arr_[0]))); + extra_arr_.erase(extra_arr_.begin()); + } return Status::OK(); } void ResetQue() noexcept { while (head_ < tail_) { T val; - this->PopFrontWhileHoldingLock(&val); + this->PopFrontWhileHoldingLock(&val, false); MS_LOG(DEBUG) << "Address of val: " << &val; } empty_cv_.ResetIntrpState(); diff --git a/tests/ut/cpp/dataset/queue_test.cc b/tests/ut/cpp/dataset/queue_test.cc index 93aeaff09af..5438ca09dfc 100644 --- a/tests/ut/cpp/dataset/queue_test.cc +++ b/tests/ut/cpp/dataset/queue_test.cc @@ -178,10 +178,10 @@ TEST_F(MindDataTestQueue, Test6) { ASSERT_EQ(*pepped_value, 99); } -// Feature: Check resize is finished without changing elements and influencing operations. -// Description: Compare elements in queue before and after resize, and test add/pop/reset. -// Expectation: Elements in queue after resize are the same as the original queue. -TEST_F(MindDataTestQueue, TestResize) { +// Feature: Test basic check in the resize. +// Description: Check false input for resize function. +// Expectation: Return false when the input is unexpected, and true when the new capacity is the same as original. +TEST_F(MindDataTestQueue, TestResize1) { // Create a list of queues with capacity = 3 Queue queue(3); ASSERT_EQ(3, queue.capacity()); @@ -189,38 +189,77 @@ TEST_F(MindDataTestQueue, TestResize) { TensorRow a; std::shared_ptr test_tensor1; std::vector input = {1.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.2, 0.7, 0.8, 0.9, 1.0, 2.0, 1.3, 3.0, 4.0}; - ASSERT_OK(Tensor::CreateFromVector(input, TensorShape{3, 5}, &test_tensor1)); + EXPECT_OK(Tensor::CreateFromVector(input, TensorShape{3, 5}, &test_tensor1)); a.push_back(test_tensor1); EXPECT_OK(queue.Add(a)); TensorRow b; std::shared_ptr test_tensor2; - ASSERT_OK(Tensor::CreateScalar(true, &test_tensor2)); + EXPECT_OK(Tensor::CreateScalar(true, &test_tensor2)); b.push_back(test_tensor2); EXPECT_OK(queue.Add(b)); TensorRow c; std::shared_ptr test_tensor3; - ASSERT_OK(Tensor::CreateFromVector(input, &test_tensor3)); + EXPECT_OK(Tensor::CreateFromVector(input, &test_tensor3)); c.push_back(test_tensor3); EXPECT_OK(queue.Add(c)); ASSERT_EQ(3, queue.size()); - // Check false if the resize is smaller than current size - EXPECT_ERROR(queue.Resize(2)); + + // Check false if input is equal to or smaller than 0 + EXPECT_ERROR(queue.Resize(0)); + EXPECT_ERROR(queue.Resize(-1)); + // Check true if the new capacity is the same as original + EXPECT_OK(queue.Resize(3)); +} + +// Feature: Check resize is finished without changing elements and influencing operations. +// Description: Compare elements in queue before and after resize, and test add/pop/reset. +// Expectation: Elements in queue after resize are the same as the original queue. +TEST_F(MindDataTestQueue, TestResize2) { + // Create a list of queues with capacity = 3 + Queue queue(3); + ASSERT_EQ(3, queue.capacity()); + // Add 3 rows into the queue + TensorRow a; + std::shared_ptr test_tensor1; + std::vector input = {1.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.2, 0.7, 0.8, 0.9, 1.0, 2.0, 1.3, 3.0, 4.0}; + EXPECT_OK(Tensor::CreateFromVector(input, TensorShape{3, 5}, &test_tensor1)); + a.push_back(test_tensor1); + EXPECT_OK(queue.Add(a)); + + TensorRow b; + std::shared_ptr test_tensor2; + EXPECT_OK(Tensor::CreateScalar(true, &test_tensor2)); + b.push_back(test_tensor2); + EXPECT_OK(queue.Add(b)); + + TensorRow c; + std::shared_ptr test_tensor3; + EXPECT_OK(Tensor::CreateFromVector(input, &test_tensor3)); + c.push_back(test_tensor3); + EXPECT_OK(queue.Add(c)); + ASSERT_EQ(3, queue.size()); + + // Check true if the resize is smaller than current size + EXPECT_OK(queue.Resize(1)); + ASSERT_EQ(1, queue.capacity()); + // Expect the rows after resize are the same as original input, there should be still 1 element in the queue + TensorRow d; + EXPECT_OK(queue.PopFront(&d)); + EXPECT_EQ(a.getRow(), d.getRow()); + ASSERT_EQ(1, queue.size()); // Check true if the resize is larger than current size, and capacity is changed EXPECT_OK(queue.Resize(12)); ASSERT_EQ(12, queue.capacity()); - TensorRow d = a; - EXPECT_OK(queue.Add(d)); - ASSERT_EQ(4, queue.size()); - // Expect the rows after resize are the same as original input - TensorRow e; - EXPECT_OK(queue.PopFront(&e)); - EXPECT_EQ(a.getRow(), e.getRow()); - EXPECT_OK(queue.PopFront(&e)); - EXPECT_EQ(b.getRow(), e.getRow()); - EXPECT_OK(queue.PopFront(&e)); - EXPECT_EQ(c.getRow(), e.getRow()); + // Check add operation after resize + EXPECT_OK(queue.Add(a)); + ASSERT_EQ(3, queue.size()); + // Check pop operation after resize + EXPECT_OK(queue.PopFront(&d)); + EXPECT_EQ(b.getRow(), d.getRow()); + EXPECT_OK(queue.PopFront(&d)); + EXPECT_EQ(c.getRow(), d.getRow()); ASSERT_EQ(1, queue.size()); queue.Reset(); ASSERT_EQ(0, queue.size());