modify lenet&alexnet dir

This commit is contained in:
wukesong 2020-05-29 15:08:02 +08:00
parent fce37a5fbe
commit 7dfd369998
17 changed files with 148 additions and 70 deletions

View File

@ -36,10 +36,9 @@ class AlexNet(nn.Cell):
""" """
Alexnet Alexnet
""" """
def __init__(self, num_classes=10): def __init__(self, num_classes=10, channel=3):
super(AlexNet, self).__init__() super(AlexNet, self).__init__()
self.batch_size = 32 self.conv1 = conv(channel, 96, 11, stride=4)
self.conv1 = conv(3, 96, 11, stride=4)
self.conv2 = conv(96, 256, 5, pad_mode="same") self.conv2 = conv(96, 256, 5, pad_mode="same")
self.conv3 = conv(256, 384, 3, pad_mode="same") self.conv3 = conv(256, 384, 3, pad_mode="same")
self.conv4 = conv(384, 384, 3, pad_mode="same") self.conv4 = conv(384, 384, 3, pad_mode="same")

View File

@ -23,7 +23,7 @@ import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1, status="train"): def create_dataset_mnist(data_path, batch_size=32, repeat_size=1, status="train"):
""" """
create dataset for train or test create dataset for train or test
""" """

View File

@ -20,10 +20,10 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
import argparse import argparse
from config import alexnet_cfg as cfg from config import alexnet_cfg as cfg
from dataset import create_dataset from dataset import create_dataset_mnist
from alexnet import AlexNet
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
@ -50,9 +50,8 @@ if __name__ == "__main__":
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
ds_eval = create_dataset(args.data_path, ds_eval = create_dataset_mnist(args.data_path,
cfg.batch_size, cfg.batch_size,
1, status="test")
"test")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc)) print("============== Accuracy:{} ==============".format(acc))

View File

@ -20,14 +20,14 @@ python train.py --data_path /YourDataPath
import argparse import argparse
from config import alexnet_cfg as cfg from config import alexnet_cfg as cfg
from dataset import create_dataset from dataset import create_dataset_mnist
from generator_lr import get_lr from generator_lr import get_lr
from alexnet import AlexNet
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@ -50,9 +50,9 @@ if __name__ == "__main__":
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) # test model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) # test
print("============== Starting Training ==============") print("============== Starting Training ==============")
ds_train = create_dataset(args.data_path, ds_train = create_dataset_mnist(args.data_path,
cfg.batch_size, cfg.batch_size,
cfg.epoch_size) cfg.epoch_size)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)

View File

@ -22,8 +22,8 @@ import os
import argparse import argparse
from dataset import create_dataset from dataset import create_dataset
from config import mnist_cfg as cfg from config import mnist_cfg as cfg
from lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

View File

@ -50,11 +50,10 @@ class LeNet5(nn.Cell):
>>> LeNet(num_class=10) >>> LeNet(num_class=10)
""" """
def __init__(self, num_class=10): def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.num_class = num_class
self.batch_size = 32 self.conv1 = conv(channel, 6, 5)
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84) self.fc2 = fc_with_initialize(120, 84)

View File

@ -22,8 +22,8 @@ import os
import argparse import argparse
from config import mnist_cfg as cfg from config import mnist_cfg as cfg
from dataset import create_dataset from dataset import create_dataset
from lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model from mindspore.train import Model
@ -36,7 +36,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()

78
tests/perf_test/lenet.py Normal file
View File

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x

View File

@ -17,12 +17,12 @@
import numpy as np import numpy as np
from lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.composite as C import mindspore.ops.composite as C
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.model_zoo.lenet import LeNet
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -61,7 +61,7 @@ def test_compile():
def test_compile_grad(): def test_compile_grad():
"""Compile forward and backward graph""" """Compile forward and backward graph"""
net = LeNet(num_class=num_class) net = LeNet5(num_class=num_class)
inp = Tensor(np.array(np.random.randn(batch_size, inp = Tensor(np.array(np.random.randn(batch_size,
channel, channel,
height, height,

View File

@ -1,46 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output

View File

@ -26,17 +26,66 @@ import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.dataset.transforms.vision import Inter from mindspore.dataset.transforms.vision import Inter
from mindspore.model_zoo.lenet import LeNet5
from mindspore.nn import Dense, TrainOneStepCell, WithLossCell from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.callback import LossMonitor from mindspore.train.callback import LossMonitor
from mindspore.common.initializer import TruncatedNormal
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class LeNet(nn.Cell): class LeNet(nn.Cell):
def __init__(self): def __init__(self):
super(LeNet, self).__init__() super(LeNet, self).__init__()