!779 Fix pylint warning for samplers.py

Merge pull request !779 from JunhanHu/sampler_pylint
This commit is contained in:
mindspore-ci-bot 2020-04-28 01:55:26 +08:00 committed by Gitee
commit fe9000812d
1 changed files with 7 additions and 1 deletions

View File

@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
User can also define custom sampler by extending from Sampler class. User can also define custom sampler by extending from Sampler class.
""" """
import mindspore._c_dataengine as cde
import numpy as np import numpy as np
import mindspore._c_dataengine as cde
class Sampler: class Sampler:
@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler):
self.shard_id = shard_id self.shard_id = shard_id
self.shuffle = shuffle self.shuffle = shuffle
self.seed = 0 self.seed = 0
super().__init__()
def create(self): def create(self):
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler):
self.num_val = num_val self.num_val = num_val
self.shuffle = shuffle self.shuffle = shuffle
self.class_column = class_column # work for minddataset self.class_column = class_column # work for minddataset
super().__init__()
def create(self): def create(self):
return cde.PKSampler(self.num_val, self.shuffle) return cde.PKSampler(self.num_val, self.shuffle)
@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler):
but got class_column={}".format(class_column)) but got class_column={}".format(class_column))
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
class RandomSampler(BuiltinSampler): class RandomSampler(BuiltinSampler):
""" """
Samples the elements randomly. Samples the elements randomly.
@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler):
self.replacement = replacement self.replacement = replacement
self.num_samples = num_samples self.num_samples = num_samples
super().__init__()
def create(self): def create(self):
# If num_samples is not specified, then call constructor #2 # If num_samples is not specified, then call constructor #2
@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler):
indices = [indices] indices = [indices]
self.indices = indices self.indices = indices
super().__init__()
def create(self): def create(self):
return cde.SubsetRandomSampler(self.indices) return cde.SubsetRandomSampler(self.indices)
@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler):
self.weights = weights self.weights = weights
self.num_samples = num_samples self.num_samples = num_samples
self.replacement = replacement self.replacement = replacement
super().__init__()
def create(self): def create(self):
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)