forked from mindspore-Ecosystem/mindspore
!779 Fix pylint warning for samplers.py
Merge pull request !779 from JunhanHu/sampler_pylint
This commit is contained in:
commit
fe9000812d
|
@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
|
|||
User can also define custom sampler by extending from Sampler class.
|
||||
"""
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
import numpy as np
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
|
||||
class Sampler:
|
||||
|
@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
self.shard_id = shard_id
|
||||
self.shuffle = shuffle
|
||||
self.seed = 0
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
|
||||
|
@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler):
|
|||
self.num_val = num_val
|
||||
self.shuffle = shuffle
|
||||
self.class_column = class_column # work for minddataset
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
return cde.PKSampler(self.num_val, self.shuffle)
|
||||
|
@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler):
|
|||
but got class_column={}".format(class_column))
|
||||
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||
|
||||
|
||||
class RandomSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the elements randomly.
|
||||
|
@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler):
|
|||
|
||||
self.replacement = replacement
|
||||
self.num_samples = num_samples
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
# If num_samples is not specified, then call constructor #2
|
||||
|
@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
indices = [indices]
|
||||
|
||||
self.indices = indices
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
return cde.SubsetRandomSampler(self.indices)
|
||||
|
@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
self.weights = weights
|
||||
self.num_samples = num_samples
|
||||
self.replacement = replacement
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
|
||||
|
|
Loading…
Reference in New Issue