From a3fe77e5b366c96d0d2f4b69a78b081d3ed1aa9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E5=A4=A9=E5=A7=BF?= Date: Wed, 8 Jul 2020 02:04:58 +0800 Subject: [PATCH] =?UTF-8?q?step1=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- MindSpore/src/step1/net.py | 53 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 MindSpore/src/step1/net.py diff --git a/MindSpore/src/step1/net.py b/MindSpore/src/step1/net.py new file mode 100644 index 0000000..66b3da3 --- /dev/null +++ b/MindSpore/src/step1/net.py @@ -0,0 +1,53 @@ + +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal + +def weight_variable(): + """ + weight initial + """ + return TruncatedNormal(0.02) + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """ + conv layer weight initial + """ + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + +def fc_with_initialize(input_channels, out_channels): + """ + fc layer weight initial + """ + # 请在此添加代码完成本关任务 + # **********Begin*********# + ##提示:完成初始化代码 + + # **********End**********# + +class LeNet5(nn.Cell): + """ + Lenet network structure + """ + #define the operator required + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + #use the preceding operators to construct networks + def construct(self, x): + # 请在此添加代码完成本关任务 + # **********Begin*********# + ##提示:根据教程内容完成网络定义即可 + + # **********End**********# + return x