diff --git a/MindSpore/src/step3/train_net.py b/MindSpore/src/step3/train_net.py new file mode 100644 index 0000000..f317a34 --- /dev/null +++ b/MindSpore/src/step3/train_net.py @@ -0,0 +1,193 @@ +import os +import urllib.request +from urllib.parse import urlparse +import gzip +import argparse +import mindspore.dataset as ds +import mindspore.nn as nn +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train import Model +from mindspore.common.initializer import TruncatedNormal +import mindspore.dataset.transforms.vision.c_transforms as CV +import mindspore.dataset.transforms.c_transforms as C +from mindspore.dataset.transforms.vision import Inter +from mindspore.nn.metrics import Accuracy +from mindspore.common import dtype as mstype +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits + + +def unzipfile(gzip_path): + """unzip dataset file + Args: + gzip_path: dataset file path + """ + open_file = open(gzip_path.replace('.gz',''), 'wb') + gz_file = gzip.GzipFile(gzip_path) + open_file.write(gz_file.read()) + gz_file.close() + + +def download_dataset(): + """Download the dataset from http://yann.lecun.com/exdb/mnist/.""" + print("******Downloading the MNIST dataset******") + train_path = "./MNIST_Data/train/" + test_path = "./MNIST_Data/test/" + train_path_check = os.path.exists(train_path) + test_path_check = os.path.exists(test_path) + if train_path_check == False and test_path_check ==False: + os.makedirs(train_path) + os.makedirs(test_path) + train_url = {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"} + test_url = {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"} + for url in train_url: + url_parse = urlparse(url) + # split the file name from url + file_name = os.path.join(train_path,url_parse.path.split('/')[-1]) + if not os.path.exists(file_name.replace('.gz','')): + file = urllib.request.urlretrieve(url, file_name) + unzipfile(file_name) + os.remove(file_name) + for url in test_url: + url_parse = urlparse(url) + # split the file name from url + file_name = os.path.join(test_path,url_parse.path.split('/')[-1]) + if not os.path.exists(file_name.replace('.gz','')): + file = urllib.request.urlretrieve(url, file_name) + unzipfile(file_name) + os.remove(file_name) + + +def create_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1): + """ create dataset for train or test + Args: + data_path: Data path + batch_size: The number of data records in each group + repeat_size: The number of replicated data records + num_parallel_workers: The number of parallel workers + """ + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + # define operation parameters + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + rescale_nml = 1 / 0.3081 + shift_nml = -1 * 0.1307 / 0.3081 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32) + rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images + rescale_op = CV.Rescale(rescale, shift) # rescale images + hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network. + type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network + + # apply map operations on images + mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + buffer_size = 10000 + mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script + mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) + mnist_ds = mnist_ds.repeat(repeat_size) + + return mnist_ds + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """Conv layer weight initial.""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """Fc layer weight initial.""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """Weight initial.""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + """Lenet network structure.""" + # define the operator required + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + # use the preceding operators to construct networks + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): + """Define the training method.""" + print("============== Starting Training ==============") + # load training dataset + # 请在此添加代码完成本关任务 + # **********Begin*********# + ##提示:完成网络的配置 + + # **********End**********# + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='MindSpore LeNet Example') + parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: CPU)') + args = parser.parse_args() + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + # download mnist dataset + download_dataset() + # learning rate setting + lr = 0.01 + momentum = 0.9 + epoch_size = 1 + mnist_path = "./MNIST_Data" + # define the loss function + net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + repeat_size = epoch_size + # create the network + network = LeNet5() + # define the optimizer + net_opt = nn.Momentum(network.trainable_params(), lr, momentum) + # 请在此添加代码完成本关任务 + # **********Begin*********# + ##提示:配置模型保存 + + # **********End**********# + # group layers into an object with training and evaluation features + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)