!8072 Revert changes to weighted_random_sampler in PR7866

Merge pull request !8072 from luoyang/pylint
This commit is contained in:
mindspore-ci-bot 2020-11-02 15:36:45 +08:00 committed by Gitee
commit 7e43eaf5e8
2 changed files with 48 additions and 0 deletions

View File

@ -19,6 +19,7 @@ SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class.
"""
import numbers
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset as ds
@ -591,6 +592,20 @@ class WeightedRandomSampler(BuiltinSampler):
if not isinstance(weights, list):
weights = [weights]
for ind, w in enumerate(weights):
if not isinstance(w, numbers.Number):
raise TypeError("type of weights element should be number, "
"but got w[{}]={}, type={}".format(ind, w, type(w)))
if weights == []:
raise ValueError("weights size should not be 0")
if list(filter(lambda x: x < 0, weights)) != []:
raise ValueError("weights should not contain negative numbers")
if list(filter(lambda x: x == 0, weights)) == weights:
raise ValueError("elements of weights should not be all zero")
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
@ -382,6 +383,35 @@ def test_weighted_random_sampler():
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 11
def test_weighted_random_sampler_exception():
"""
Test error cases for WeightedRandomSampler
"""
logger.info("Test error cases for WeightedRandomSampler")
error_msg_1 = "type of weights element should be number"
with pytest.raises(TypeError, match=error_msg_1):
weights = ""
ds.WeightedRandomSampler(weights)
error_msg_2 = "type of weights element should be number"
with pytest.raises(TypeError, match=error_msg_2):
weights = (0.9, 0.8, 1.1)
ds.WeightedRandomSampler(weights)
error_msg_3 = "weights size should not be 0"
with pytest.raises(ValueError, match=error_msg_3):
weights = []
ds.WeightedRandomSampler(weights)
error_msg_4 = "weights should not contain negative numbers"
with pytest.raises(ValueError, match=error_msg_4):
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
ds.WeightedRandomSampler(weights)
error_msg_5 = "elements of weights should not be all zero"
with pytest.raises(ValueError, match=error_msg_5):
weights = [0, 0, 0, 0, 0]
ds.WeightedRandomSampler(weights)
def test_imagefolder_rename():
logger.info("Test Case rename")
@ -465,6 +495,9 @@ if __name__ == '__main__':
test_weighted_random_sampler()
logger.info('test_weighted_random_sampler Ended.\n')
test_weighted_random_sampler_exception()
logger.info('test_weighted_random_sampler_exception Ended.\n')
test_imagefolder_numshards()
logger.info('test_imagefolder_numshards Ended.\n')