forked from mindspore-Ecosystem/mindspore
add vgg scripts
This commit is contained in:
parent
930a1fb0a8
commit
31b165a57e
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
cifar_cfg = edict({
|
||||
'num_classes': 10,
|
||||
'lr_init': 0.05,
|
||||
'batch_size': 64,
|
||||
'epoch_size': 70,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 5e-4,
|
||||
'buffer_size': 10,
|
||||
'image_height': 224,
|
||||
'image_width': 224,
|
||||
'keep_checkpoint_max': 10
|
||||
})
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import os
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
import mindspore.common.dtype as mstype
|
||||
from config import cifar_cfg as cfg
|
||||
|
||||
def create_dataset(data_home, repeat_num=1, training=True):
|
||||
"""Data operations."""
|
||||
ds.config.set_seed(1)
|
||||
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
|
||||
if not training:
|
||||
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
|
||||
data_set = ds.Cifar10Dataset(data_dir)
|
||||
resize_height = cfg.image_height
|
||||
resize_width = cfg.image_width
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
|
||||
# define map operations
|
||||
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
|
||||
random_horizontal_op = vision.RandomHorizontalFlip()
|
||||
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
|
||||
rescale_op = vision.Rescale(rescale, shift)
|
||||
normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
|
||||
changeswap_op = vision.HWC2CHW()
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
|
||||
c_trans = []
|
||||
if training:
|
||||
c_trans = [random_crop_op, random_horizontal_op]
|
||||
c_trans += [resize_op, rescale_op, normalize_op,
|
||||
changeswap_op]
|
||||
|
||||
# apply map operations on images
|
||||
data_set = data_set.map(input_columns="label", operations=type_cast_op)
|
||||
data_set = data_set.map(input_columns="image", operations=c_trans)
|
||||
|
||||
# apply repeat operations
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
# apply shuffle operations
|
||||
data_set = data_set.shuffle(buffer_size=10)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
|
||||
|
||||
return data_set
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############test vgg16 example on cifar10#################
|
||||
python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
|
||||
"""
|
||||
import argparse
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.model_zoo.vgg import vgg16
|
||||
from config import cifar_cfg as cfg
|
||||
import dataset
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Cifar10 classification')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.')
|
||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
context.set_context(enable_mem_reuse=True, enable_hccl=False)
|
||||
|
||||
net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
|
||||
weight_decay=cfg.weight_decay)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
dataset = dataset.create_dataset(args_opt.data_path, 1, False)
|
||||
res = model.eval(dataset)
|
||||
print("result: ", res)
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
#################train vgg16 example on cifar10########################
|
||||
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
|
||||
"""
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore import context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.model_zoo.vgg import vgg16
|
||||
import dataset
|
||||
from config import cifar_cfg as cfg
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
|
||||
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
||||
"""Set learning rate."""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
for i in range(total_steps):
|
||||
if i < decay_epoch_index[0]:
|
||||
lr_each_step.append(lr_max)
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr_each_step.append(lr_max * 0.1)
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr_each_step.append(lr_max * 0.01)
|
||||
else:
|
||||
lr_each_step.append(lr_max * 0.001)
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
||||
return learning_rate
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Cifar10 classification')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
|
||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
context.set_context(enable_mem_reuse=True, enable_hccl=False)
|
||||
|
||||
net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
|
||||
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
dataset = dataset.create_dataset(args_opt.data_path, cfg.epoch_size)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
model.train(cfg.epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
|
|
@ -15,7 +15,8 @@
|
|||
"""VGG."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from mindspore.common.initializer import initializer
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
def _make_layer(base, batch_norm):
|
||||
"""Make stage network of VGG."""
|
||||
|
@ -25,11 +26,14 @@ def _make_layer(base, batch_norm):
|
|||
if v == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
weight_shape = (v, in_channels, 3, 3)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
|
||||
conv2d = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=v,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
pad_mode='pad')
|
||||
padding=0,
|
||||
pad_mode='same',
|
||||
weight_init=weight)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
|
||||
else:
|
||||
|
@ -52,13 +56,13 @@ class Vgg(nn.Cell):
|
|||
Tensor, infer output tensor.
|
||||
|
||||
Examples:
|
||||
>>> VGG([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
>>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
>>> num_classes=1000, batch_norm=False, batch_size=1)
|
||||
"""
|
||||
|
||||
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
|
||||
super(Vgg, self).__init__()
|
||||
self.layers = _make_layer(base, batch_norm=batch_norm)
|
||||
self.avgpool = nn.AvgPool2d(7)
|
||||
self.reshape = P.Reshape()
|
||||
self.shp = (batch_size, -1)
|
||||
self.classifier = nn.SequentialCell([
|
||||
|
@ -70,7 +74,6 @@ class Vgg(nn.Cell):
|
|||
|
||||
def construct(self, x):
|
||||
x = self.layers(x)
|
||||
x = self.avgpool(x)
|
||||
x = self.reshape(x, self.shp)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
@ -84,15 +87,20 @@ cfg = {
|
|||
}
|
||||
|
||||
|
||||
def vgg16():
|
||||
def vgg16(batch_size=1, num_classes=1000):
|
||||
"""
|
||||
Get VGG16 neural network.
|
||||
Get Vgg16 neural network with batch normalization.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size. Default: 1.
|
||||
num_classes (int): Class numbers. Default: 1000.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of VGG16 neural network.
|
||||
Cell, cell instance of Vgg16 neural network with batch normalization.
|
||||
|
||||
Examples:
|
||||
>>> vgg16()
|
||||
>>> vgg16(batch_size=1, num_classes=1000)
|
||||
"""
|
||||
net = Vgg(cfg['16'], num_classes=1000, batch_norm=False, batch_size=1)
|
||||
|
||||
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size)
|
||||
return net
|
||||
|
|
Loading…
Reference in New Issue