From 8f16cff12e42c3f6db40ea1afe8732149fdd000d Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Tue, 9 Jun 2020 14:27:08 +0800 Subject: [PATCH] add para check for sampler --- mindspore/dataset/engine/samplers.py | 25 +++++++++++++++++++++++++ tests/ut/python/dataset/test_sampler.py | 1 - 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index b593b193e8a..ca633745a8a 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -218,6 +218,11 @@ class DistributedSampler(BuiltinSampler): if not isinstance(shuffle, bool): raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) + if num_samples is not None: + if num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(num_samples)) + self.num_shards = num_shards self.shard_id = shard_id self.shuffle = shuffle @@ -282,6 +287,11 @@ class PKSampler(BuiltinSampler): if not isinstance(shuffle, bool): raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) + if num_samples is not None: + if num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(num_samples)) + self.num_val = num_val self.shuffle = shuffle self.class_column = class_column # work for minddataset @@ -385,6 +395,16 @@ class SequentialSampler(BuiltinSampler): """ def __init__(self, start_index=None, num_samples=None): + if num_samples is not None: + if num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(num_samples)) + + if start_index is not None: + if start_index < 0: + raise ValueError("start_index should be a positive integer " + "value or 0, but got start_index={}".format(start_index)) + self.start_index = start_index super().__init__(num_samples) @@ -430,6 +450,11 @@ class SubsetRandomSampler(BuiltinSampler): """ def __init__(self, indices, num_samples=None): + if num_samples is not None: + if num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(num_samples)) + if not isinstance(indices, list): indices = [indices] diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 381b6dafe7a..a7ec89c2092 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -43,7 +43,6 @@ def test_sequential_sampler(print_res=False): assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2] assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2 - assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2 assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2