clean pylint warnings of parallel test cases

This commit is contained in:
Yi Huaijie 2020-05-25 10:33:54 +08:00
parent aeffccb7f8
commit 8cfc05e4cf
14 changed files with 53 additions and 65 deletions

View File

@ -72,7 +72,7 @@ class DataGenerator():
i = 0
for stra in strategy:
temp = []
while len(blocks) > 0:
while blocks:
block = blocks.pop(0)
temp.extend(np.split(block, stra, axis=i))
blocks.extend(temp)

View File

@ -63,7 +63,7 @@ class DataGenerator():
i = 0
for stra in strategy:
temp = []
while len(blocks) > 0:
while blocks:
block = blocks.pop(0)
temp.extend(np.split(block, stra, axis=i))
blocks.extend(temp)

View File

@ -172,10 +172,12 @@ class ResNet(nn.Cell):
layer_nums,
in_channels,
out_channels,
strides=[1, 2, 2, 2],
strides=None,
num_classes=100):
super(ResNet, self).__init__()
if strides is None:
strides = [1, 2, 2, 2]
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
@ -300,7 +302,7 @@ class DataGenerator():
i = 0
for stra in strategy:
temp = []
while len(blocks) > 0:
while blocks:
block = blocks.pop(0)
temp.extend(np.split(block, stra, axis=i))
blocks.extend(temp)

View File

@ -38,17 +38,6 @@ class NetWithLoss(nn.Cell):
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y, z, w):
return C.grad_all(self.network)(x, y, z, w)
# model_parallel test
def test_common_parameter():
class Net(nn.Cell):
def __init__(self):

View File

@ -174,9 +174,9 @@ def test_reshape_auto_4():
def test_reshape_auto_5():
class NetWithLoss(nn.Cell):
class NetWithLoss5(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
super(NetWithLoss5, self).__init__()
self.loss = VirtualLoss()
self.network = network
@ -184,9 +184,9 @@ def test_reshape_auto_5():
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
class GradWrap5(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
super(GradWrap5, self).__init__()
self.network = network
def construct(self, x, y):
@ -217,16 +217,16 @@ def test_reshape_auto_5():
x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
net = GradWrap5(NetWithLoss5(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y)
def test_reshape_auto_6():
class NetWithLoss(nn.Cell):
class NetWithLoss6(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
super(NetWithLoss6, self).__init__()
self.loss = VirtualLoss()
self.network = network
@ -234,9 +234,9 @@ def test_reshape_auto_6():
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
class GradWrap6(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
super(GradWrap6, self).__init__()
self.network = network
def construct(self, x, y):
@ -265,7 +265,7 @@ def test_reshape_auto_6():
x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
net = GradWrap6(NetWithLoss6(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y)

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import re
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
@ -36,35 +36,33 @@ context.set_context(device_id=0)
init()
def weight_variable(shape, factor=0.1):
def weight_variable():
return TruncatedNormal(0.02)
def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
"""Get a conv2d layer with 3x3 kernel size."""
init_value = weight_variable((out_channels, in_channels, 3, 3))
init_value = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
"""Get a conv2d layer with 1x1 kernel size."""
init_value = weight_variable((out_channels, in_channels, 1, 1))
init_value = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
"""Get a conv2d layer with 7x7 kernel size."""
init_value = weight_variable((out_channels, in_channels, 7, 7))
init_value = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _fused_bn(channels, momentum=0.9):
"""Get a fused batchnorm"""
init_weight = weight_variable((channels,))
init_bias = weight_variable((channels,))
return nn.BatchNorm2d(channels, momentum=momentum)
@ -132,10 +130,11 @@ class ResNet(nn.Cell):
layer_nums,
in_channels,
out_channels,
strides=[1, 2, 2, 2],
strides=None,
num_classes=100):
super(ResNet, self).__init__()
if strides is None:
strides = [1, 2, 2, 2]
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of "
"layer_num, inchannel, outchannel list must be 4!")
@ -168,16 +167,13 @@ class ResNet(nn.Cell):
self.mean = P.ReduceMean(keep_dims=True)
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
weight_init=weight_variable((num_classes, 2048)),
bias_init=weight_variable((num_classes,))).add_flags_recursive(fp16=True)
weight_init=weight_variable(),
bias_init=weight_variable()).add_flags_recursive(fp16=True)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
layers = []
down_sample = False
if stride != 1 or in_channel != out_channel:
down_sample = True
resblk = block(in_channel, out_channel, stride=1)
layers.append(resblk)
@ -279,7 +275,7 @@ class DatasetLenet():
return 1
def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
set_algo_parameters(elementwise_op_strategy_follow=True)
@ -309,12 +305,12 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
return allreduce_fusion_dict
def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): # 1048576 #131072 #32768 #8192
def train_32k_8p_fusion1(batch_size=32, num_classes=32768): # 1048576 #131072 #32768 #8192
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
allreduce_fusion_dict = test_train_32k_8p(epoch_size, batch_size, num_classes)
allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
expect_dict = {'end_point.bias': 2,
'end_point.weight': 2,
'layer4.2.bn3.beta': 2,
@ -477,17 +473,17 @@ def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): # 104
'bn1.gamma': 1,
'conv1.weight': 1}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # 1048576 #131072 #32768 #8192
def train_32k_8p_fusion2(batch_size=32, num_classes=32768): # 1048576 #131072 #32768 #8192
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
allreduce_fusion_dict = test_train_32k_8p(epoch_size, batch_size, num_classes)
allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
expect_dict = {'end_point.bias': 2,
'end_point.weight': 2,
'layer4.2.bn3.beta': 2,
@ -650,11 +646,11 @@ def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # 104
'bn1.gamma': 1,
'conv1.weight': 1}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): # 1048576 #131072 #32768 #8192
def test_train_64k_8p(batch_size=32, num_classes=65536): # 1048576 #131072 #32768 #8192
dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)

View File

@ -58,7 +58,7 @@ def test_zig_zag_graph():
def construct(self, x, y, z, w, a):
m1_result = self.matmul1(x, y)
m2_result = self.matmul2(z, w)
m3_result = self.matmul3(m2_result, m1_result)
_ = self.matmul3(m2_result, m1_result)
out = self.matmul4(m2_result, a)
return out

View File

@ -101,7 +101,7 @@ def fixme_test_dataset_interface_sens_scalar():
class TrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
def __init__(self, network, optimizer):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
@ -135,7 +135,7 @@ def test_dataset_interface_sens_shape_not_equal_loss():
sens = Tensor(np.ones([256, 1024]), dtype=ms.float32)
try:
loss_scale_manager_sens(strategy1, sens)
except:
except BaseException:
pass

View File

@ -45,8 +45,10 @@ class GradWrap(nn.Cell):
class Net(nn.Cell):
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=[64, 64]):
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
super().__init__()
if shape is None:
shape = [64, 64]
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
self.mul = P.Mul().set_strategy(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)

View File

@ -221,14 +221,14 @@ def test_axis1_auto_batch_parallel():
def test_axis1_batch_parallel():
gather_v2_strategy = ((device_number, 1), (1, ))
gather_v2_strategy = ((device_number, 1), (1,))
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
rank = 2
net_trains(criterion, rank)
def test_axis1_strategy1():
gather_v2_strategy = ((16, 2), (1, ))
gather_v2_strategy = ((16, 2), (1,))
rank = 17
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
net_trains(criterion, rank)

View File

@ -265,7 +265,6 @@ class BNReshapeDenseBNNet(nn.Cell):
def test_bn_reshape_dense_bn_train_loss():
batch_size = 16
device_num = 16
context.set_auto_parallel_context(device_num=device_num, global_rank=0)
input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]), dtype=ms.int32)

View File

@ -104,7 +104,7 @@ def test_onehot_batch_parallel_invalid_strategy():
strategy4 = ((16, 1), (16, 1))
try:
compile_graph(strategy1, strategy2, strategy3, strategy4)
except:
except BaseException:
pass
@ -144,7 +144,7 @@ def test_onehot_batch_parallel_invalid_strategy_axis0():
strategy4 = ((16, 1), (16, 1))
try:
compile_graph(strategy1, strategy2, strategy3, strategy4, onthot_axis=0)
except:
except BaseException:
pass

View File

@ -124,9 +124,9 @@ def test_prelu_parallel_success2():
def test_prelu_parallel_success3():
class NetWithLoss(nn.Cell):
class NetWithLoss3(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
super(NetWithLoss3, self).__init__()
self.loss = VirtualLoss()
self.network = network
@ -134,9 +134,9 @@ def test_prelu_parallel_success3():
predict = self.network(x, y, w)
return self.loss(predict)
class GradWrap(nn.Cell):
class GradWrap3(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
super(GradWrap3, self).__init__()
self.network = network
def construct(self, x, y, w):
@ -161,7 +161,7 @@ def test_prelu_parallel_success3():
x = Tensor(np.random.rand(128, 64), dtype=ms.float32)
y = Tensor(np.random.rand(64, 16), dtype=ms.float32)
w = Tensor(np.random.rand(16), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
net = GradWrap3(NetWithLoss3(Net(strategy1, strategy2)))
net.set_auto_parallel()
_executor.compile(net, x, y, w)

View File

@ -114,7 +114,7 @@ def test_reshape1_strategy_1():
strategy_loss = ((8, 1), (8, 1))
try:
reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
except:
except BaseException:
pass
@ -125,7 +125,7 @@ def test_reshape1_strategy_2():
strategy_loss = ((8, 1), (8, 1))
try:
reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
except:
except BaseException:
pass
@ -347,14 +347,14 @@ def test_reshape_net3_2():
def test_reshape_net4_1():
try:
reshape_net2(ReshapeNet4(((1, 8), (8, 1))))
except:
except BaseException:
pass
def test_reshape_net4_2():
try:
reshape_net2(ReshapeNet4(((1, 8), (8, 2))))
except:
except BaseException:
pass