add cpu st lenet

This commit is contained in:
kswang 2020-03-31 21:25:48 +08:00
parent e2df848597
commit 2dc9f632c1
1 changed files with 34 additions and 33 deletions

View File

@ -12,25 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
test network
Usage:
python test_network_main.py --net lenet --target Davinci
"""
import os
import time
import pytest
import numpy as np
import argparse
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
import mindspore.context as context
from mindspore.nn.optim import Momentum
from models.lenet import LeNet
from models.resnetv1_5 import resnet50
from models.alexnet import AlexNet
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def train(net, data, label):
@ -48,15 +67,6 @@ def train(net, data, label):
print("+++++++++++++++++++++++++++")
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resnet50():
data = Tensor(np.ones([32, 3 ,224, 224]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = resnet50(32, 10)
train(net, data, label)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -65,12 +75,3 @@ def test_lenet():
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
train(net, data, label)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_alexnet():
data = Tensor(np.ones([32, 3 ,227, 227]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = AlexNet()
train(net, data, label)