test_net
This commit is contained in:
parent
f99e4df579
commit
e8945f2b00
|
@ -0,0 +1,205 @@
|
||||||
|
"""Lenet Tutorial
|
||||||
|
The sample can be run on CPU, GPU and Ascend 910 AI processor.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import gzip
|
||||||
|
import argparse
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||||
|
from mindspore.train import Model
|
||||||
|
from mindspore.common.initializer import TruncatedNormal
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as CV
|
||||||
|
import mindspore.dataset.transforms.c_transforms as C
|
||||||
|
from mindspore.dataset.transforms.vision import Inter
|
||||||
|
from mindspore.nn.metrics import Accuracy
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||||
|
|
||||||
|
|
||||||
|
def unzipfile(gzip_path):
|
||||||
|
"""unzip dataset file
|
||||||
|
Args:
|
||||||
|
gzip_path: dataset file path
|
||||||
|
"""
|
||||||
|
open_file = open(gzip_path.replace('.gz',''), 'wb')
|
||||||
|
gz_file = gzip.GzipFile(gzip_path)
|
||||||
|
open_file.write(gz_file.read())
|
||||||
|
gz_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
def download_dataset():
|
||||||
|
"""Download the dataset from http://yann.lecun.com/exdb/mnist/."""
|
||||||
|
print("******Downloading the MNIST dataset******")
|
||||||
|
train_path = "./MNIST_Data/train/"
|
||||||
|
test_path = "./MNIST_Data/test/"
|
||||||
|
train_path_check = os.path.exists(train_path)
|
||||||
|
test_path_check = os.path.exists(test_path)
|
||||||
|
if train_path_check == False and test_path_check ==False:
|
||||||
|
os.makedirs(train_path)
|
||||||
|
os.makedirs(test_path)
|
||||||
|
train_url = {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"}
|
||||||
|
test_url = {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
|
||||||
|
for url in train_url:
|
||||||
|
url_parse = urlparse(url)
|
||||||
|
# split the file name from url
|
||||||
|
file_name = os.path.join(train_path,url_parse.path.split('/')[-1])
|
||||||
|
if not os.path.exists(file_name.replace('.gz','')):
|
||||||
|
file = urllib.request.urlretrieve(url, file_name)
|
||||||
|
unzipfile(file_name)
|
||||||
|
os.remove(file_name)
|
||||||
|
for url in test_url:
|
||||||
|
url_parse = urlparse(url)
|
||||||
|
# split the file name from url
|
||||||
|
file_name = os.path.join(test_path,url_parse.path.split('/')[-1])
|
||||||
|
if not os.path.exists(file_name.replace('.gz','')):
|
||||||
|
file = urllib.request.urlretrieve(url, file_name)
|
||||||
|
unzipfile(file_name)
|
||||||
|
os.remove(file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||||
|
num_parallel_workers=1):
|
||||||
|
""" create dataset for train or test
|
||||||
|
Args:
|
||||||
|
data_path: Data path
|
||||||
|
batch_size: The number of data records in each group
|
||||||
|
repeat_size: The number of replicated data records
|
||||||
|
num_parallel_workers: The number of parallel workers
|
||||||
|
"""
|
||||||
|
# define dataset
|
||||||
|
mnist_ds = ds.MnistDataset(data_path)
|
||||||
|
|
||||||
|
# define operation parameters
|
||||||
|
resize_height, resize_width = 32, 32
|
||||||
|
rescale = 1.0 / 255.0
|
||||||
|
shift = 0.0
|
||||||
|
rescale_nml = 1 / 0.3081
|
||||||
|
shift_nml = -1 * 0.1307 / 0.3081
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32)
|
||||||
|
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images
|
||||||
|
rescale_op = CV.Rescale(rescale, shift) # rescale images
|
||||||
|
hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.
|
||||||
|
type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network
|
||||||
|
|
||||||
|
# apply map operations on images
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
|
||||||
|
|
||||||
|
# apply DatasetOps
|
||||||
|
buffer_size = 10000
|
||||||
|
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
|
||||||
|
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
|
||||||
|
mnist_ds = mnist_ds.repeat(repeat_size)
|
||||||
|
|
||||||
|
return mnist_ds
|
||||||
|
|
||||||
|
|
||||||
|
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||||
|
"""Conv layer weight initial."""
|
||||||
|
weight = weight_variable()
|
||||||
|
return nn.Conv2d(in_channels, out_channels,
|
||||||
|
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
weight_init=weight, has_bias=False, pad_mode="valid")
|
||||||
|
|
||||||
|
|
||||||
|
def fc_with_initialize(input_channels, out_channels):
|
||||||
|
"""Fc layer weight initial."""
|
||||||
|
weight = weight_variable()
|
||||||
|
bias = weight_variable()
|
||||||
|
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def weight_variable():
|
||||||
|
"""Weight initial."""
|
||||||
|
return TruncatedNormal(0.02)
|
||||||
|
|
||||||
|
|
||||||
|
class LeNet5(nn.Cell):
|
||||||
|
"""Lenet network structure."""
|
||||||
|
# define the operator required
|
||||||
|
def __init__(self):
|
||||||
|
super(LeNet5, self).__init__()
|
||||||
|
self.conv1 = conv(1, 6, 5)
|
||||||
|
self.conv2 = conv(6, 16, 5)
|
||||||
|
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||||
|
self.fc2 = fc_with_initialize(120, 84)
|
||||||
|
self.fc3 = fc_with_initialize(84, 10)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
|
||||||
|
# use the preceding operators to construct networks
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.max_pool2d(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.max_pool2d(x)
|
||||||
|
x = self.flatten(x)
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.fc3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
|
||||||
|
"""Define the training method."""
|
||||||
|
print("============== Starting Training ==============")
|
||||||
|
# load training dataset
|
||||||
|
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
|
||||||
|
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net(args, network, model, mnist_path):
|
||||||
|
"""Define the evaluation method."""
|
||||||
|
print("============== Starting Testing ==============")
|
||||||
|
# load the saved model for evaluation
|
||||||
|
# 请在此添加代码完成本关任务
|
||||||
|
# **********Begin*********#
|
||||||
|
##提示:补全验证函数的代码
|
||||||
|
|
||||||
|
# **********End**********#
|
||||||
|
print("============== Accuracy:{} ==============".format(acc))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
|
||||||
|
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
|
||||||
|
help='device where the code will be implemented (default: CPU)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
# download mnist dataset
|
||||||
|
download_dataset()
|
||||||
|
# learning rate setting
|
||||||
|
lr = 0.01
|
||||||
|
momentum = 0.9
|
||||||
|
epoch_size = 1
|
||||||
|
mnist_path = "./MNIST_Data"
|
||||||
|
# define the loss function
|
||||||
|
net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
|
||||||
|
repeat_size = epoch_size
|
||||||
|
# create the network
|
||||||
|
network = LeNet5()
|
||||||
|
# define the optimizer
|
||||||
|
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
|
||||||
|
# save the network model and parameters for subsequence fine-tuning
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||||
|
# group layers into an object with training and evaluation features
|
||||||
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
|
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
|
||||||
|
test_net(args, network, model, mnist_path)
|
Loading…
Reference in New Issue