forked from mindspore-Ecosystem/mindspore
modify longtime python ut
This commit is contained in:
parent
d84bf8d357
commit
eaf7146d46
|
@ -159,7 +159,7 @@ class Conv2d(_Conv):
|
||||||
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
||||||
>>> input = mindspore.Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
|
>>> input = mindspore.Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
|
||||||
>>> net(input).shape()
|
>>> net(input).shape()
|
||||||
(1, 240, 1024, 637)
|
(1, 240, 1024, 640)
|
||||||
"""
|
"""
|
||||||
@cell_attr_register
|
@cell_attr_register
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -49,16 +49,16 @@ def test_reshape_matmul():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.matmul = P.MatMul()
|
self.matmul = P.MatMul()
|
||||||
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
|
self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
out = self.reshape(x, (256, 25088))
|
out = self.reshape(x, (64, 28))
|
||||||
out = self.matmul(out, self.matmul_weight)
|
out = self.matmul(out, self.matmul_weight)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
size = 8
|
size = 8
|
||||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||||
x = Tensor(np.ones([32*size, 512, 7, 7]), dtype=ms.float32)
|
x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
|
||||||
|
|
||||||
net = GradWrap(NetWithLoss(Net()))
|
net = GradWrap(NetWithLoss(Net()))
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||||
|
|
|
@ -247,15 +247,15 @@ def fc_with_initialize(input_channels, out_channels):
|
||||||
class BNReshapeDenseBNNet(nn.Cell):
|
class BNReshapeDenseBNNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(BNReshapeDenseBNNet, self).__init__()
|
super(BNReshapeDenseBNNet, self).__init__()
|
||||||
self.batch_norm = bn_with_initialize(512)
|
self.batch_norm = bn_with_initialize(2)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
|
self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
|
||||||
self.fc = fc_with_initialize(512 * 32 * 32, 512)
|
self.fc = fc_with_initialize(2 * 32 * 32, 512)
|
||||||
self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
|
self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
|
||||||
|
|
||||||
def construct(self, x, label):
|
def construct(self, x, label):
|
||||||
x = self.batch_norm(x)
|
x = self.batch_norm(x)
|
||||||
x = self.reshape(x, (16, 512*32*32))
|
x = self.reshape(x, (16, 2*32*32))
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
x = self.batch_norm2(x)
|
x = self.batch_norm2(x)
|
||||||
loss = self.loss(x, label)
|
loss = self.loss(x, label)
|
||||||
|
@ -266,7 +266,7 @@ def test_bn_reshape_dense_bn_train_loss():
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
device_num = 16
|
device_num = 16
|
||||||
context.set_auto_parallel_context(device_num=device_num, global_rank=0)
|
context.set_auto_parallel_context(device_num=device_num, global_rank=0)
|
||||||
input = Tensor(np.ones([batch_size, 512, 32, 32]).astype(np.float32) * 0.01)
|
input = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
|
||||||
label = Tensor(np.ones([batch_size]), dtype=ms.int32)
|
label = Tensor(np.ones([batch_size]), dtype=ms.int32)
|
||||||
|
|
||||||
net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
|
net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
|
||||||
|
|
|
@ -490,15 +490,15 @@ def fc_with_initialize(input_channels, out_channels):
|
||||||
class BNReshapeDenseBNNet(nn.Cell):
|
class BNReshapeDenseBNNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(BNReshapeDenseBNNet, self).__init__()
|
super(BNReshapeDenseBNNet, self).__init__()
|
||||||
self.batch_norm = bn_with_initialize(512)
|
self.batch_norm = bn_with_initialize(2)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
|
self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
|
||||||
self.fc = fc_with_initialize(512 * 32 * 32, 512)
|
self.fc = fc_with_initialize(2 * 32 * 32, 512)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x = self.batch_norm(x)
|
x = self.batch_norm(x)
|
||||||
x = self.reshape(x, (16, 512*32*32))
|
x = self.reshape(x, (16, 2*32*32))
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
x = self.batch_norm2(x)
|
x = self.batch_norm2(x)
|
||||||
return x
|
return x
|
||||||
|
@ -508,7 +508,7 @@ def test_bn_reshape_dense_bn_train():
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
device_num = 16
|
device_num = 16
|
||||||
context.set_auto_parallel_context(device_num=device_num, global_rank=0)
|
context.set_auto_parallel_context(device_num=device_num, global_rank=0)
|
||||||
input = Tensor(np.ones([batch_size, 512, 32, 32]).astype(np.float32) * 0.01)
|
input = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
|
||||||
|
|
||||||
net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
|
net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||||
|
|
|
@ -43,9 +43,9 @@ def get_test_data(step):
|
||||||
tag1 = "xt1[:Tensor]"
|
tag1 = "xt1[:Tensor]"
|
||||||
tag2 = "xt2[:Tensor]"
|
tag2 = "xt2[:Tensor]"
|
||||||
tag3 = "xt3[:Tensor]"
|
tag3 = "xt3[:Tensor]"
|
||||||
np1 = np.random.random((50, 40, 30, 50))
|
np1 = np.random.random((5, 4, 3, 5))
|
||||||
np2 = np.random.random((50, 50, 30, 50))
|
np2 = np.random.random((5, 5, 3, 5))
|
||||||
np3 = np.random.random((40, 55, 30, 50))
|
np3 = np.random.random((4, 5, 3, 5))
|
||||||
|
|
||||||
dict1 = {}
|
dict1 = {}
|
||||||
dict1["name"] = tag1
|
dict1["name"] = tag1
|
||||||
|
|
Loading…
Reference in New Issue