forked from mindspore-Ecosystem/mindspore
remove the batch_size parameter of VGG16, since nn.Flatten can replace the reshape that required a static batch size.
This commit is contained in:
parent ebd0fd33f6
commit b36094e327
@@ -39,7 +39,7 @@ if __name__ == '__main__':
     context.set_context(device_id=args_opt.device_id)
     context.set_context(enable_mem_reuse=True, enable_hccl=False)
 
-    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
+    net = vgg16(num_classes=cfg.num_classes)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
                    weight_decay=cfg.weight_decay)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
@@ -64,7 +64,7 @@ if __name__ == '__main__':
     context.set_context(device_id=args_opt.device_id)
     context.set_context(enable_mem_reuse=True, enable_hccl=False)
 
-    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
+    net = vgg16(num_classes=cfg.num_classes)
     lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
@@ -14,7 +14,6 @@
 # ============================================================================
 """VGG."""
 import mindspore.nn as nn
-from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 
@@ -63,8 +62,7 @@ class Vgg(nn.Cell):
     def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
         super(Vgg, self).__init__()
         self.layers = _make_layer(base, batch_norm=batch_norm)
-        self.reshape = P.Reshape()
-        self.shp = (batch_size, -1)
+        self.flatten = nn.Flatten()
         self.classifier = nn.SequentialCell([
             nn.Dense(512 * 7 * 7, 4096),
             nn.ReLU(),
@@ -74,7 +72,7 @@ class Vgg(nn.Cell):
 
     def construct(self, x):
         x = self.layers(x)
-        x = self.reshape(x, self.shp)
+        x = self.flatten(x)
         x = self.classifier(x)
         return x
 
@@ -87,20 +85,19 @@ cfg = {
 }
 
 
-def vgg16(batch_size=1, num_classes=1000):
+def vgg16(num_classes=1000):
     """
     Get Vgg16 neural network with batch normalization.
 
     Args:
-        batch_size (int): Batch size. Default: 1.
         num_classes (int): Class numbers. Default: 1000.
 
     Returns:
         Cell, cell instance of Vgg16 neural network with batch normalization.
 
     Examples:
-        >>> vgg16(batch_size=1, num_classes=1000)
+        >>> vgg16(num_classes=1000)
     """
 
-    net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size)
+    net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True)
     return net
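
For context (not part of the commit): P.Reshape takes an explicit target shape, so the old code had to bake cfg.batch_size into the cell at construction time, whereas nn.Flatten keeps dimension 0 and flattens the rest of whatever input it receives. A minimal sketch of the difference, assuming a recent MindSpore build in PyNative mode; shapes and values are illustrative:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

# A batch of 8 VGG feature maps, (N, C, H, W) = (8, 512, 7, 7).
x = Tensor(np.ones((8, 512, 7, 7), np.float32))

# Old approach: Reshape needs the full target shape, so the batch size
# (here 8) had to be known when the cell was built.
y_old = P.Reshape()(x, (8, -1))

# New approach: Flatten keeps dim 0 and flattens the remaining dims,
# so the same cell works for any batch size.
y_new = nn.Flatten()(x)

# Both produce (8, 25088); only Flatten survives a change of batch size.
assert y_old.asnumpy().shape == y_new.asnumpy().shape == (8, 25088)

This is why vgg16() no longer needs a batch_size argument: the flattened width is derived from each input at run time instead of being fixed at construction.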