User submitted
This commit is contained in:
parent
d1e8799462
commit
674b699f14
|
@ -19,8 +19,6 @@ from mindspore.dataset.transforms.vision import Inter
|
|||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
|
||||
|
||||
def unzipfile(gzip_path):
|
||||
"""unzip dataset file
|
||||
Args:
|
||||
|
@ -30,8 +28,6 @@ def unzipfile(gzip_path):
|
|||
gz_file = gzip.GzipFile(gzip_path)
|
||||
open_file.write(gz_file.read())
|
||||
gz_file.close()
|
||||
|
||||
|
||||
def download_dataset():
|
||||
"""Download the dataset from http://yann.lecun.com/exdb/mnist/."""
|
||||
print("******Downloading the MNIST dataset******")
|
||||
|
@ -60,8 +56,6 @@ def download_dataset():
|
|||
file = urllib.request.urlretrieve(url, file_name)
|
||||
unzipfile(file_name)
|
||||
os.remove(file_name)
|
||||
|
||||
|
||||
def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||
num_parallel_workers=1):
|
||||
""" create dataset for train or test
|
||||
|
@ -73,57 +67,44 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
|
|||
"""
|
||||
# define dataset
|
||||
mnist_ds = ds.MnistDataset(data_path)
|
||||
|
||||
# define operation parameters
|
||||
resize_height, resize_width = 32, 32
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
rescale_nml = 1 / 0.3081
|
||||
shift_nml = -1 * 0.1307 / 0.3081
|
||||
|
||||
# define map operations
|
||||
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32)
|
||||
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images
|
||||
rescale_op = CV.Rescale(rescale, shift) # rescale images
|
||||
hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.
|
||||
type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network
|
||||
|
||||
# apply map operations on images
|
||||
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
|
||||
|
||||
# apply DatasetOps
|
||||
buffer_size = 10000
|
||||
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
|
||||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
|
||||
mnist_ds = mnist_ds.repeat(repeat_size)
|
||||
|
||||
return mnist_ds
|
||||
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
"""Conv layer weight initial."""
|
||||
weight = weight_variable()
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
weight_init=weight, has_bias=False, pad_mode="valid")
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
"""Fc layer weight initial."""
|
||||
weight = weight_variable()
|
||||
bias = weight_variable()
|
||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""Weight initial."""
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
"""Lenet network structure."""
|
||||
# define the operator required
|
||||
|
@ -137,7 +118,6 @@ class LeNet5(nn.Cell):
|
|||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
# use the preceding operators to construct networks
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
|
@ -153,16 +133,12 @@ class LeNet5(nn.Cell):
|
|||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
|
||||
"""Define the training method."""
|
||||
print("============== Starting Training ==============")
|
||||
# load training dataset
|
||||
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
|
||||
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_net(args, network, model, mnist_path):
|
||||
"""Define the evaluation method."""
|
||||
print("============== Starting Testing ==============")
|
||||
|
@ -170,11 +146,14 @@ def test_net(args, network, model, mnist_path):
|
|||
# 请在此添加代码完成本关任务
|
||||
#********** Begin *********#
|
||||
## 提示:补全验证函数的代码
|
||||
|
||||
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
|
||||
# load parameter to the network
|
||||
load_param_into_net(network, param_dict)
|
||||
# load testing dataset
|
||||
ds_eval = create_dataset(os.path.join(mnist_path, "test"))
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
#********** End **********#
|
||||
print("============== Accuracy:{} ==============".format(acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
|
||||
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
|
||||
|
@ -200,6 +179,5 @@ if __name__ == "__main__":
|
|||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
# group layers into an object with training and evaluation features
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
|
||||
test_net(args, network, model, mnist_path)
|
||||
test_net(args, network, model, mnist_path)
|
Loading…
Reference in New Issue