forked from mindspore-Ecosystem/mindspore
!995 Clean some pylint-warnings
Merge pull request !995 from SJN/master
This commit is contained in:
commit
2c4fec57a0
|
@ -41,23 +41,22 @@ from config import ConfigYOLOV3ResNet18
|
|||
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
|
||||
"""Set learning rate."""
|
||||
lr_each_step = []
|
||||
lr = learning_rate
|
||||
for i in range(global_step):
|
||||
if steps:
|
||||
lr_each_step.append(lr * (decay_rate ** (i // decay_step)))
|
||||
lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step)))
|
||||
else:
|
||||
lr_each_step.append(lr * (decay_rate ** (i / decay_step)))
|
||||
lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step)))
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
lr_each_step = lr_each_step[start_step:]
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def init_net_param(net, init='ones'):
|
||||
"""Init the parameters in net."""
|
||||
params = net.trainable_params()
|
||||
def init_net_param(network, init_value='ones'):
|
||||
"""Init:wq the parameters in network."""
|
||||
params = network.trainable_params()
|
||||
for p in params:
|
||||
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||
p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype()))
|
||||
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
@ -33,6 +33,7 @@ class AKGMetaPathFinder:
|
|||
|
||||
def find_module(self, fullname, path=None):
|
||||
"""method _akg find module."""
|
||||
_ = path
|
||||
if fullname.startswith("_akg.tvm"):
|
||||
rname = fullname[5:]
|
||||
return AKGMetaPathLoader(rname)
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
"""format transform function"""
|
||||
import _akg
|
||||
|
||||
def refine_reduce_axis(input, axis):
|
||||
def refine_reduce_axis(input_content, axis):
|
||||
"""make reduce axis legal."""
|
||||
shape = get_shape(input)
|
||||
shape = get_shape(input_content)
|
||||
if axis is None:
|
||||
axis = [i for i in range(len(shape))]
|
||||
elif isinstance(axis, int):
|
||||
|
|
Loading…
Reference in New Issue