forked from mindspore-Ecosystem/mindspore
[MD] Fix bugs in WeightedRandomSampler & SubsetRandomSampler
This commit is contained in:
parent
a6075cc73b
commit
f0e976dbda
|
@ -299,6 +299,24 @@ WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights,
|
|||
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
|
||||
|
||||
bool WeightedRandomSamplerObj::ValidateParams() {
|
||||
if (weights_.empty()) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
|
||||
return false;
|
||||
}
|
||||
int32_t zero_elem = 0;
|
||||
for (int32_t i = 0; i < weights_.size(); ++i) {
|
||||
if (weights_[i] < 0) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
|
||||
return false;
|
||||
}
|
||||
if (weights_[i] == 0.0) {
|
||||
zero_elem++;
|
||||
}
|
||||
}
|
||||
if (zero_elem == weights_.size()) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
|
||||
return false;
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
|
|
|
@ -103,7 +103,7 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
int64_t sampled_id = indices_[sample_id_];
|
||||
int64_t sampled_id = ((indices_[sample_id_] % num_rows_) + num_rows_) % num_rows_;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
|
|
@ -585,6 +585,15 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
if not isinstance(weights, list):
|
||||
weights = [weights]
|
||||
|
||||
if weights == []:
|
||||
raise ValueError("weights size should not be 0")
|
||||
|
||||
if list(filter(lambda x: x < 0, weights)):
|
||||
raise ValueError("weights should not contain negative numbers")
|
||||
|
||||
if list(filter(lambda x: x == 0, weights)) == weights:
|
||||
raise ValueError("elements of weights should not be all zero")
|
||||
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
|
|
|
@ -99,3 +99,20 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
|
|||
EXPECT_TRUE(indices.empty());
|
||||
EXPECT_NE(sampl2->Build(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
|
||||
// weights is empty
|
||||
std::vector<double> weights1 = {};
|
||||
std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights1);
|
||||
EXPECT_EQ(sampl1, nullptr);
|
||||
|
||||
// weights has negative number
|
||||
std::vector<double> weights2 = {0.5, 0.2, -0.4};
|
||||
std::shared_ptr<SamplerObj> sampl2 = WeightedRandomSampler(weights2);
|
||||
EXPECT_EQ(sampl2, nullptr);
|
||||
|
||||
// weights elements are all zero
|
||||
std::vector<double> weights3 = {0, 0, 0};
|
||||
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
|
||||
EXPECT_EQ(sampl3, nullptr);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue