第二关代码

This commit is contained in:
马天姿 2020-07-08 02:42:20 +08:00
parent cef954e0d0
commit dc7503c488
1 changed files with 35 additions and 0 deletions

View File

@ -0,0 +1,35 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# download mnist dataset
download_dataset()
# learning rate setting
lr = 0.01
momentum = 0.9
epoch_size = 1
mnist_path = "./MNIST_Data"
# 请在此添加代码完成本关任务
# **********Begin*********#
##提示:补全损失函数的定义
# **********End**********#
repeat_size = epoch_size
# create the network
network = LeNet5()
# define the optimizer
# 请在此添加代码完成本关任务
# **********Begin*********#
##提示:补全优化器的定义
# **********End**********#
# save the network model and parameters for subsequence fine-tuning
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# group layers into an object with training and evaluation features
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
test_net(args, network, model, mnist_path)