forked from mindspore-Ecosystem/mindspore
!8072 Revert changes to weighted_random_sampler in PR7866
Merge pull request !8072 from luoyang/pylint
This commit is contained in:
commit
7e43eaf5e8
|
@ -19,6 +19,7 @@ SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
|
|||
Users can also define a custom sampler by extending from the Sampler class.
|
||||
"""
|
||||
|
||||
import numbers
|
||||
import numpy as np
|
||||
import mindspore._c_dataengine as cde
|
||||
import mindspore.dataset as ds
|
||||
|
@ -591,6 +592,20 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
if not isinstance(weights, list):
|
||||
weights = [weights]
|
||||
|
||||
for ind, w in enumerate(weights):
|
||||
if not isinstance(w, numbers.Number):
|
||||
raise TypeError("type of weights element should be number, "
|
||||
"but got w[{}]={}, type={}".format(ind, w, type(w)))
|
||||
|
||||
if weights == []:
|
||||
raise ValueError("weights size should not be 0")
|
||||
|
||||
if list(filter(lambda x: x < 0, weights)) != []:
|
||||
raise ValueError("weights should not contain negative numbers")
|
||||
|
||||
if list(filter(lambda x: x == 0, weights)) == weights:
|
||||
raise ValueError("elements of weights should not be all zero")
|
||||
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
|
@ -382,6 +383,35 @@ def test_weighted_random_sampler():
|
|||
logger.info("Number of data in data1: {}".format(num_iter))
|
||||
assert num_iter == 11
|
||||
|
||||
def test_weighted_random_sampler_exception():
|
||||
"""
|
||||
Test error cases for WeightedRandomSampler
|
||||
"""
|
||||
logger.info("Test error cases for WeightedRandomSampler")
|
||||
error_msg_1 = "type of weights element should be number"
|
||||
with pytest.raises(TypeError, match=error_msg_1):
|
||||
weights = ""
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
error_msg_2 = "type of weights element should be number"
|
||||
with pytest.raises(TypeError, match=error_msg_2):
|
||||
weights = (0.9, 0.8, 1.1)
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
error_msg_3 = "weights size should not be 0"
|
||||
with pytest.raises(ValueError, match=error_msg_3):
|
||||
weights = []
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
error_msg_4 = "weights should not contain negative numbers"
|
||||
with pytest.raises(ValueError, match=error_msg_4):
|
||||
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
error_msg_5 = "elements of weights should not be all zero"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
weights = [0, 0, 0, 0, 0]
|
||||
ds.WeightedRandomSampler(weights)
|
||||
|
||||
def test_imagefolder_rename():
|
||||
logger.info("Test Case rename")
|
||||
|
@ -465,6 +495,9 @@ if __name__ == '__main__':
|
|||
test_weighted_random_sampler()
|
||||
logger.info('test_weighted_random_sampler Ended.\n')
|
||||
|
||||
test_weighted_random_sampler_exception()
|
||||
logger.info('test_weighted_random_sampler_exception Ended.\n')
|
||||
|
||||
test_imagefolder_numshards()
|
||||
logger.info('test_imagefolder_numshards Ended.\n')
|
||||
|
||||
|
|
Loading…
Reference in New Issue