第二关代码
This commit is contained in:
parent
cef954e0d0
commit
dc7503c488
|
@ -0,0 +1,35 @@
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
|
||||
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented (default: CPU)')
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
# download mnist dataset
|
||||
download_dataset()
|
||||
# learning rate setting
|
||||
lr = 0.01
|
||||
momentum = 0.9
|
||||
epoch_size = 1
|
||||
mnist_path = "./MNIST_Data"
|
||||
# 请在此添加代码完成本关任务
|
||||
# **********Begin*********#
|
||||
##提示:补全损失函数的定义
|
||||
|
||||
# **********End**********#
|
||||
repeat_size = epoch_size
|
||||
# create the network
|
||||
network = LeNet5()
|
||||
# define the optimizer
|
||||
# 请在此添加代码完成本关任务
|
||||
# **********Begin*********#
|
||||
##提示:补全优化器的定义
|
||||
|
||||
# **********End**********#
|
||||
|
||||
# save the network model and parameters for subsequence fine-tuning
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
# group layers into an object with training and evaluation features
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
|
||||
test_net(args, network, model, mnist_path)
|
Loading…
Reference in New Issue