forked from mindspore-Ecosystem/mindspore
!779 Fix pylint warning for samplers.py
Merge pull request !779 from JunhanHu/sampler_pylint
This commit is contained in:
commit
fe9000812d
|
@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
|
||||||
User can also define custom sampler by extending from Sampler class.
|
User can also define custom sampler by extending from Sampler class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import mindspore._c_dataengine as cde
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import mindspore._c_dataengine as cde
|
||||||
|
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
|
@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler):
|
||||||
self.shard_id = shard_id
|
self.shard_id = shard_id
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.seed = 0
|
self.seed = 0
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
|
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
|
||||||
|
@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler):
|
||||||
self.num_val = num_val
|
self.num_val = num_val
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.class_column = class_column # work for minddataset
|
self.class_column = class_column # work for minddataset
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
return cde.PKSampler(self.num_val, self.shuffle)
|
return cde.PKSampler(self.num_val, self.shuffle)
|
||||||
|
@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler):
|
||||||
but got class_column={}".format(class_column))
|
but got class_column={}".format(class_column))
|
||||||
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||||
|
|
||||||
|
|
||||||
class RandomSampler(BuiltinSampler):
|
class RandomSampler(BuiltinSampler):
|
||||||
"""
|
"""
|
||||||
Samples the elements randomly.
|
Samples the elements randomly.
|
||||||
|
@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler):
|
||||||
|
|
||||||
self.replacement = replacement
|
self.replacement = replacement
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
# If num_samples is not specified, then call constructor #2
|
# If num_samples is not specified, then call constructor #2
|
||||||
|
@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler):
|
||||||
indices = [indices]
|
indices = [indices]
|
||||||
|
|
||||||
self.indices = indices
|
self.indices = indices
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
return cde.SubsetRandomSampler(self.indices)
|
return cde.SubsetRandomSampler(self.indices)
|
||||||
|
@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler):
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
self.replacement = replacement
|
self.replacement = replacement
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
|
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
|
||||||
|
|
Loading…
Reference in New Issue