[MD] Fix bugs in WeightedRandomSampler & SubsetRandomSampler

This commit is contained in:
luoyang 2020-10-24 19:28:35 +08:00
parent a6075cc73b
commit f0e976dbda
4 changed files with 45 additions and 1 deletions

View File

@ -299,6 +299,24 @@ WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights,
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
bool WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
return false;
}
int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
return false;
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
return false;
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
return false;

View File

@ -103,7 +103,7 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
RETURN_STATUS_UNEXPECTED(err_msg);
}
int64_t sampled_id = indices_[sample_id_];
int64_t sampled_id = ((indices_[sample_id_] % num_rows_) + num_rows_) % num_rows_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}

View File

@ -585,6 +585,15 @@ class WeightedRandomSampler(BuiltinSampler):
if not isinstance(weights, list):
weights = [weights]
if weights == []:
raise ValueError("weights size should not be 0")
if list(filter(lambda x: x < 0, weights)):
raise ValueError("weights should not contain negative numbers")
if list(filter(lambda x: x == 0, weights)) == weights:
raise ValueError("elements of weights should not be all zero")
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "

View File

@ -99,3 +99,20 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->Build(), nullptr);
}
TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
// weights is empty
std::vector<double> weights1 = {};
std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights1);
EXPECT_EQ(sampl1, nullptr);
// weights has negative number
std::vector<double> weights2 = {0.5, 0.2, -0.4};
std::shared_ptr<SamplerObj> sampl2 = WeightedRandomSampler(weights2);
EXPECT_EQ(sampl2, nullptr);
// weights elements are all zero
std::vector<double> weights3 = {0, 0, 0};
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
EXPECT_EQ(sampl3, nullptr);
}