Validate WeightedRandomSampler weights (non-empty, no negatives, not all zero) in both the
C++ and Python APIs, wrap SubsetRandomSampler indices into range, and add a failure unit test.

diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc
index ed4274df373..a56add0dc2d 100644
--- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc
@@ -299,6 +299,26 @@ WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector weights,
   : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
 
 bool WeightedRandomSamplerObj::ValidateParams() {
+  // Reject weight vectors that are empty, contain a negative value, or are all zero.
+  if (weights_.empty()) {
+    MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
+    return false;
+  }
+  // size_t matches the type of weights_.size(), avoiding a signed/unsigned comparison below.
+  size_t zero_elem = 0;
+  for (const auto &weight : weights_) {
+    if (weight < 0) {
+      MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weight;
+      return false;
+    }
+    if (weight == 0.0) {
+      ++zero_elem;
+    }
+  }
+  if (zero_elem == weights_.size()) {
+    MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
+    return false;
+  }
   if (num_samples_ < 0) {
     MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
     return false;
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
index 0d58f23d5e5..b1f251a8d5b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
@@ -103,7 +103,9 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffe
         RETURN_STATUS_UNEXPECTED(err_msg);
       }
 
-      int64_t sampled_id = indices_[sample_id_];
+      // Wrap the index into [0, num_rows_) so that a negative index counts back from the end.
+      // NOTE(review): assumes num_rows_ > 0 (otherwise % divides by zero) -- confirm it is checked earlier.
+      int64_t sampled_id = ((indices_[sample_id_] % num_rows_) + num_rows_) % num_rows_;
       if (HasChildSampler()) {
         RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
       }
diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py
index c5039fb15cf..6ff1c3413ba 100644
--- a/mindspore/dataset/engine/samplers.py
+++ b/mindspore/dataset/engine/samplers.py
@@ -585,6 +585,16 @@ class WeightedRandomSampler(BuiltinSampler):
         if not isinstance(weights, list):
             weights = [weights]
 
+        # These checks mirror the C++ WeightedRandomSamplerObj::ValidateParams validation.
+        if not weights:
+            raise ValueError("weights size should not be 0")
+
+        if any(w < 0 for w in weights):
+            raise ValueError("weights should not contain negative numbers")
+
+        if all(w == 0 for w in weights):
+            raise ValueError("elements of weights should not be all zero")
+
         if num_samples is not None:
             if num_samples <= 0:
                 raise ValueError("num_samples should be a positive integer "
diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc
index 45605663f3b..47b8a43170c 100644
--- a/tests/ut/cpp/dataset/c_api_samplers_test.cc
+++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc
@@ -99,3 +99,21 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
   EXPECT_TRUE(indices.empty());
   EXPECT_NE(sampl2->Build(), nullptr);
 }
+
+// Each invalid weights vector below is expected to make WeightedRandomSampler return nullptr.
+TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
+  // weights is empty
+  std::vector<double> weights1 = {};
+  std::shared_ptr sampl1 = WeightedRandomSampler(weights1);
+  EXPECT_EQ(sampl1, nullptr);
+
+  // weights has negative number
+  std::vector<double> weights2 = {0.5, 0.2, -0.4};
+  std::shared_ptr sampl2 = WeightedRandomSampler(weights2);
+  EXPECT_EQ(sampl2, nullptr);
+
+  // weights elements are all zero
+  std::vector<double> weights3 = {0, 0, 0};
+  std::shared_ptr sampl3 = WeightedRandomSampler(weights3);
+  EXPECT_EQ(sampl3, nullptr);
+}