diff --git a/mindspore/ccsrc/dataset/util/circular_pool.cc b/mindspore/ccsrc/dataset/util/circular_pool.cc index 0c68dab81bf..42cccd87ed3 100644 --- a/mindspore/ccsrc/dataset/util/circular_pool.cc +++ b/mindspore/ccsrc/dataset/util/circular_pool.cc @@ -88,6 +88,9 @@ Status CircularPool::Allocate(size_t n, void **p) { while (cirIt.has_next()) { auto it = cirIt.Next(); Arena *ba = it->get(); + if (ba->get_max_size() < n) { + return Status(StatusCode::kOutOfMemory); + } // If we are asked to move forward the tail if (move_tail) { Arena *expected = cirIt.cur_tail_; diff --git a/mindspore/ccsrc/dataset/util/queue.h b/mindspore/ccsrc/dataset/util/queue.h index 7fca93d944a..52309962d58 100644 --- a/mindspore/ccsrc/dataset/util/queue.h +++ b/mindspore/ccsrc/dataset/util/queue.h @@ -182,6 +182,9 @@ class Queue { arr_[k].~T(); } } + for (uint64_t i = 0; i < sz_; i++) { + std::allocator_traits>::construct(alloc_, &(arr_[i])); + } empty_cv_.ResetIntrpState(); full_cv_.ResetIntrpState(); head_ = 0; diff --git a/tests/ut/cpp/dataset/queue_test.cc b/tests/ut/cpp/dataset/queue_test.cc index 578405e5370..05c80ea50ff 100644 --- a/tests/ut/cpp/dataset/queue_test.cc +++ b/tests/ut/cpp/dataset/queue_test.cc @@ -19,6 +19,8 @@ #include "dataset/util/task_manager.h" #include "dataset/util/queue.h" #include +#include +#include #include "utils/log_adapter.h" using namespace mindspore::dataset; @@ -39,7 +41,7 @@ class RefCount { public: RefCount() : v_(nullptr) {} explicit RefCount(int x) : v_(std::make_shared(x)) {} - explicit RefCount(const RefCount &o) : v_(o.v_) {} + RefCount(const RefCount &o) : v_(o.v_) {} ~RefCount() { MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl; gRefCountDestructorCalled++; @@ -167,3 +169,70 @@ TEST_F(MindDataTestQueue, Test6) { MS_LOG(INFO) << "Popped value " << *pepped_value << " from queue index " << chosen_queue_index; ASSERT_EQ(*pepped_value, 99); } +using namespace std::chrono; +template +void Perf(int n, int p, std::string name) { + auto payload = std::vector(n, PayloadType(p)); + auto queue = QueueType(n); + auto t0 = high_resolution_clock::now(); + auto check = 0; + for (int i = 0; i < queue.capacity(); i++) { + queue.Add(PayloadType(p)); + } + check = queue.size(); + for (int i = 0; i < queue.capacity(); i++) { + queue.PopFront(&payload[i]); + } + auto t1 = high_resolution_clock::now(); + std::cout << name << " queue filled size: " << queue.size() << " " << check << std::endl; + auto t2 = high_resolution_clock::now(); + for (int i = 0; i < queue.capacity(); i++) { + queue.Add(PayloadType(p)); + } + check = queue.size(); + for (int i = 0; i < queue.capacity(); i++) { + queue.PopFront(&payload[i]); + } + auto t3 = high_resolution_clock::now(); + auto d = duration_cast(t3 - t2 + t1 - t0).count(); + std::cout << name << " queue emptied size: " << queue.size() << " " << check << std::endl; + std::cout << name << " " + << " ran in " << d << "ms" << std::endl; +} + +template +void Fuzz(int n, int p, std::string name) { + std::mt19937 gen(1); + auto payload = std::vector(n, PayloadType(p)); + auto queue = QueueType(n); + auto dist = std::uniform_int_distribution(0, 2); + std::cout << "###" << std::endl; + for (auto i = 0; i < n; i++) { + auto v = dist(gen); + if (v == 0 && queue.size() < n - 1) { + queue.Add(std::move(payload[i])); + } + if (v == 1 && queue.size() > 0) { + queue.PopFront(&payload[i]); + } else { + queue.Reset(); + } + } + std::cout << name << " fuzz ran " << queue.size() << std::endl; +} +TEST_F(MindDataTestQueue, TestPerf) { + try { + int kSz = 1000000; + // std::cout << "enter size" << std::endl; + // std::cin >> kSz; + Perf>, std::vector>(kSz, 1, "old queue, vector of size 1"); + } catch (const std::exception &e) { + std::cout << e.what() << std::endl; + } + + std::cout << "Test Reset" << std::endl; + std::cout << "Enter fuzz size" << std::endl; + int fs = 1000; +// std::cin >> fs; + Fuzz>, std::vector>(fs, 1, "New queue"); +}