forked from mindspore-Ecosystem/mindspore
!5670 modify lenet network
Merge pull request !5670 from wukesong/lenet_network
This commit is contained in:
commit
1c3fc5c49b
|
@ -14,27 +14,6 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""LeNet."""
|
"""LeNet."""
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.common.initializer import TruncatedNormal
|
|
||||||
|
|
||||||
|
|
||||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
|
||||||
"""weight initial for conv layer"""
|
|
||||||
weight = weight_variable()
|
|
||||||
return nn.Conv2d(in_channels, out_channels,
|
|
||||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
weight_init=weight, has_bias=False, pad_mode="valid")
|
|
||||||
|
|
||||||
|
|
||||||
def fc_with_initialize(input_channels, out_channels):
|
|
||||||
"""weight initial for fc layer"""
|
|
||||||
weight = weight_variable()
|
|
||||||
bias = weight_variable()
|
|
||||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
|
||||||
|
|
||||||
|
|
||||||
def weight_variable():
|
|
||||||
"""weight initial"""
|
|
||||||
return TruncatedNormal(0.02)
|
|
||||||
|
|
||||||
|
|
||||||
class LeNet5(nn.Cell):
|
class LeNet5(nn.Cell):
|
||||||
|
@ -43,6 +22,7 @@ class LeNet5(nn.Cell):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_class (int): Num classes. Default: 10.
|
num_class (int): Num classes. Default: 10.
|
||||||
|
channel (int): Num classes. Default: 1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, output tensor
|
Tensor, output tensor
|
||||||
|
@ -53,26 +33,20 @@ class LeNet5(nn.Cell):
|
||||||
def __init__(self, num_class=10, channel=1):
|
def __init__(self, num_class=10, channel=1):
|
||||||
super(LeNet5, self).__init__()
|
super(LeNet5, self).__init__()
|
||||||
self.num_class = num_class
|
self.num_class = num_class
|
||||||
self.conv1 = conv(channel, 6, 5)
|
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
|
||||||
self.conv2 = conv(6, 16, 5)
|
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
|
||||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
self.fc1 = nn.Dense(16 * 5 * 5, 120)
|
||||||
self.fc2 = fc_with_initialize(120, 84)
|
self.fc2 = nn.Dense(120, 84)
|
||||||
self.fc3 = fc_with_initialize(84, self.num_class)
|
self.fc3 = nn.Dense(84, self.num_class)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
self.flatten = nn.Flatten()
|
self.flatten = nn.Flatten()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x = self.conv1(x)
|
x = self.max_pool2d(self.relu(self.conv1(x)))
|
||||||
x = self.relu(x)
|
x = self.max_pool2d(self.relu(self.conv2(x)))
|
||||||
x = self.max_pool2d(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.max_pool2d(x)
|
|
||||||
x = self.flatten(x)
|
x = self.flatten(x)
|
||||||
x = self.fc1(x)
|
x = self.relu(self.fc1(x))
|
||||||
x = self.relu(x)
|
x = self.relu(self.fc2(x))
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.fc3(x)
|
x = self.fc3(x)
|
||||||
return x
|
return x
|
||||||
|
|
Loading…
Reference in New Issue