forked from mindspore-Ecosystem/mindspore
Rename Sampler Reset
This commit is contained in:
parent
a1f194c971
commit
dcfdff6021
|
@ -409,7 +409,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Reset Sampler and wakeup Master thread (functor)
|
||||
Status CelebAOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
wp_.Set(); // wake up master thread after reset is done
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -241,7 +241,7 @@ void CifarOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Reset Sampler and wakeup Master thread (functor)
|
||||
Status CifarOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
row_cnt_ = 0;
|
||||
wp_.Set(); // wake up master thread after reset is done
|
||||
return Status::OK();
|
||||
|
|
|
@ -207,7 +207,7 @@ void CocoOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
|
||||
Status CocoOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
row_cnt_ = 0;
|
||||
wp_.Set();
|
||||
return Status::OK();
|
||||
|
|
|
@ -252,7 +252,7 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Reset Sampler and wakeup Master thread (functor)
|
||||
Status ImageFolderOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
row_cnt_ = 0;
|
||||
wp_.Set(); // wake up master thread after reset is done
|
||||
return Status::OK();
|
||||
|
|
|
@ -241,7 +241,7 @@ void ManifestOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Reset Sampler and wakeup Master thread (functor)
|
||||
Status ManifestOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
row_cnt_ = 0;
|
||||
wp_.Set(); // wake up master thread after reset is done
|
||||
return Status::OK();
|
||||
|
|
|
@ -204,7 +204,7 @@ void MnistOp::Print(std::ostream &out, bool show_all) const {
|
|||
|
||||
// Reset Sampler and wakeup Master thread (functor)
|
||||
Status MnistOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
row_cnt_ = 0;
|
||||
wp_.Set(); // wake up master thread after reset is done
|
||||
return Status::OK();
|
||||
|
|
|
@ -89,7 +89,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DistributedSampler::Reset() {
|
||||
Status DistributedSampler::ResetSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
|
||||
cnt_ = 0;
|
||||
|
||||
|
@ -100,7 +100,7 @@ Status DistributedSampler::Reset() {
|
|||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -47,7 +47,7 @@ class DistributedSampler : public Sampler {
|
|||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
|
|
|
@ -94,13 +94,13 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PKSampler::Reset() {
|
||||
Status PKSampler::ResetSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
rnd_.seed(seed_++);
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -54,7 +54,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED
|
|||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
private:
|
||||
bool shuffle_;
|
||||
|
|
|
@ -84,7 +84,7 @@ Status PythonSampler::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonSampler::Reset() {
|
||||
Status PythonSampler::ResetSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
|
||||
need_to_reset_ = false;
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
|
@ -98,7 +98,7 @@ Status PythonSampler::Reset() {
|
|||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -42,7 +42,7 @@ class PythonSampler : public Sampler {
|
|||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
|
|
|
@ -91,7 +91,7 @@ Status RandomSampler::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RandomSampler::Reset() {
|
||||
Status RandomSampler::ResetSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
|
||||
|
@ -106,7 +106,7 @@ Status RandomSampler::Reset() {
|
|||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -48,7 +48,7 @@ class RandomSampler : public Sampler {
|
|||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
virtual void Print(std::ostream &out, bool show_all) const;
|
||||
|
||||
|
|
|
@ -113,7 +113,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
|
|||
RETURN_IF_NOT_OK(GetNextSample(&db));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
|
||||
// Reset Sampler since this is the end of the epoch
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
RETURN_IF_NOT_OK(ResetSampler());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -62,6 +62,8 @@ class Sampler {
|
|||
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
|
||||
|
||||
Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}
|
||||
|
||||
// default destructor
|
||||
~Sampler() = default;
|
||||
|
||||
|
@ -77,7 +79,7 @@ class Sampler {
|
|||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
virtual Status Reset() = 0;
|
||||
virtual Status ResetSampler() = 0;
|
||||
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
// in the dataset that we can sample from.
|
||||
|
@ -109,8 +111,16 @@ class Sampler {
|
|||
// @return - The error code returned.
|
||||
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
|
||||
|
||||
// A print method typically used for debugging
|
||||
// @param out - The output stream to write output to
|
||||
// @param show_all - A bool to control if you want to show all info or just a summary
|
||||
virtual void Print(std::ostream &out, bool show_all) const;
|
||||
|
||||
// << Stream output operator overload
|
||||
// @notes This allows you to write the debug print info using stream operators
|
||||
// @param out - reference to the output stream being overloaded
|
||||
// @param sampler - reference to teh sampler to print
|
||||
// @return - the output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
|
||||
sampler.Print(out, false);
|
||||
return out;
|
||||
|
|
|
@ -77,13 +77,13 @@ Status SequentialSampler::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SequentialSampler::Reset() {
|
||||
Status SequentialSampler::ResetSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
|
||||
current_id_ = start_index_;
|
||||
id_count_ = 0;
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -41,7 +41,7 @@ class SequentialSampler : public Sampler {
|
|||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
|
|
|
@ -55,7 +55,7 @@ Status SubsetRandomSampler::InitSampler() {
|
|||
}
|
||||
|
||||
// Reset the internal variable to the initial state.
|
||||
Status SubsetRandomSampler::Reset() {
|
||||
Status SubsetRandomSampler::ResetSampler() {
|
||||
// Reset the internal counters.
|
||||
sample_id_ = 0;
|
||||
buffer_id_ = 0;
|
||||
|
@ -65,7 +65,7 @@ Status SubsetRandomSampler::Reset() {
|
|||
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -44,7 +44,7 @@ class SubsetRandomSampler : public Sampler {
|
|||
|
||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||
// @return Status
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Get the sample ids.
|
||||
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
|
|
|
@ -77,7 +77,7 @@ void WeightedRandomSampler::InitOnePassSampling() {
|
|||
}
|
||||
|
||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||
Status WeightedRandomSampler::Reset() {
|
||||
Status WeightedRandomSampler::ResetSampler() {
|
||||
sample_id_ = 0;
|
||||
buffer_id_ = 0;
|
||||
rand_gen_.seed(GetSeed());
|
||||
|
@ -88,7 +88,7 @@ Status WeightedRandomSampler::Reset() {
|
|||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -46,7 +46,7 @@ class WeightedRandomSampler : public Sampler {
|
|||
Status InitSampler() override;
|
||||
|
||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||
Status Reset() override;
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Get the sample ids.
|
||||
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
|
|
|
@ -177,7 +177,7 @@ void VOCOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
|
||||
Status VOCOp::Reset() {
|
||||
RETURN_IF_NOT_OK(sampler_->Reset());
|
||||
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||
row_cnt_ = 0;
|
||||
wp_.Set();
|
||||
return Status::OK();
|
||||
|
|
|
@ -94,7 +94,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
|
|||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
EXPECT_TRUE((*tensor) == (*label2));
|
||||
sampler->Reset();
|
||||
sampler->ResetSampler();
|
||||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
EXPECT_TRUE((*tensor) == (*label1));
|
||||
|
|
|
@ -123,7 +123,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
|
|||
ASSERT_NE(in_set.find(out[i]), in_set.end());
|
||||
}
|
||||
|
||||
sampler.Reset();
|
||||
sampler.ResetSampler();
|
||||
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), false);
|
||||
|
|
|
@ -214,7 +214,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
|
||||
m_sampler.Reset();
|
||||
m_sampler.ResetSampler();
|
||||
out.clear();
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
|
@ -259,7 +259,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
|
||||
m_sampler.Reset();
|
||||
m_sampler.ResetSampler();
|
||||
out.clear();
|
||||
freq.clear();
|
||||
freq.resize(total_samples, 0);
|
||||
|
|
Loading…
Reference in New Issue