forked from OSSInnovation/mindspore
!3839 GPU: inceptionv3 support in modelzoo
Merge pull request !3839 from hhc2020/gpu-inceptionv3-modelzoo
This commit is contained in:
commit
ad3d490d1e
|
@ -0,0 +1,115 @@
|
|||
# Inception-v3 Example
|
||||
|
||||
## Description
|
||||
|
||||
This is an example of training Inception-v3 in MindSpore.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [Mindspore](http://www.mindspore.cn/install/en).
|
||||
- Downlaod the dataset.
|
||||
|
||||
## Structure
|
||||
|
||||
```shell
|
||||
.
|
||||
└─Inception-v3
|
||||
├─README.md
|
||||
├─scripts
|
||||
├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p)
|
||||
├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p)
|
||||
└─run_eval_for_gpu.sh # launch evaluating with gpu platform
|
||||
├─src
|
||||
├─config.py # parameter configuration
|
||||
├─dataset.py # data preprocessing
|
||||
├─inception_v3.py # network definition
|
||||
├─loss.py # Customized CrossEntropy loss function
|
||||
├─lr_generator.py # learning rate generator
|
||||
├─eval.py # eval net
|
||||
├─export.py # convert checkpoint
|
||||
└─train.py # train net
|
||||
|
||||
```
|
||||
|
||||
## Parameter Configuration
|
||||
|
||||
Parameters for both training and evaluating can be set in config.py
|
||||
|
||||
```
|
||||
'random_seed': 1, # fix random seed
|
||||
'rank': 0, # local rank of distributed
|
||||
'group_size': 1, # world size of distributed
|
||||
'work_nums': 8, # number of workers to read the data
|
||||
'decay_method': 'cosine', # learning rate scheduler mode
|
||||
"loss_scale": 1, # loss scale
|
||||
'batch_size': 128, # input batchsize
|
||||
'epoch_size': 250, # total epoch numbers
|
||||
'num_classes': 1000, # dataset class numbers
|
||||
'smooth_factor': 0.1, # label smoothing factor
|
||||
'aux_factor': 0.2, # loss factor of aux logit
|
||||
'lr_init': 0.00004, # initiate learning rate
|
||||
'lr_max': 0.4, # max bound of learning rate
|
||||
'lr_end': 0.000004, # min bound of learning rate
|
||||
'warmup_epochs': 1, # warmup epoch numbers
|
||||
'weight_decay': 0.00004, # weight decay
|
||||
'momentum': 0.9, # momentum
|
||||
'opt_eps': 1.0, # epsilon
|
||||
'keep_checkpoint_max': 100, # max numbers to keep checkpoints
|
||||
'ckpt_path': './checkpoint/', # save checkpoint path
|
||||
'is_save_on_master': 1 # save checkpoint on rank0, distributed parameters
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Running the example
|
||||
|
||||
### Train
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# distribute training example(8p)
|
||||
sh run_distribute_train_for_gpu.sh DATA_DIR
|
||||
# standalone training
|
||||
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```bash
|
||||
# distributed training example(8p) for GPU
|
||||
sh scripts/run_distribute_train_for_gpu.sh /dataset/train
|
||||
# standalone training example for GPU
|
||||
sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
|
||||
```
|
||||
|
||||
#### Result
|
||||
|
||||
You can find checkpoint file together with result in log.
|
||||
|
||||
### Evaluation
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# Evaluation
|
||||
sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```bash
|
||||
# Evaluation with checkpoint
|
||||
sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/inceptionv3-rank3-247_1251.ckpt
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
#### Result
|
||||
|
||||
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.
|
||||
|
||||
```
|
||||
acc=78.75%(TOP1)
|
||||
acc=94.07%(TOP5)
|
||||
```
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluate_imagenet"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.inception_v3 import InceptionV3
|
||||
from src.loss import CrossEntropy_Val
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification evaluation')
|
||||
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
|
||||
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
|
||||
ckpt = load_checkpoint(args_opt.checkpoint)
|
||||
load_param_into_net(net, ckpt)
|
||||
net.set_train(False)
|
||||
dataset = create_dataset(args_opt.dataset_path, False, 0, 1)
|
||||
loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes)
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top1-Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': nn.Top5CategoricalAccuracy()}
|
||||
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
|
||||
metrics = model.eval(dataset)
|
||||
print("metric: ", metrics)
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into geir and onnx models#################
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.inception_v3 import InceptionV3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='checkpoint export')
|
||||
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
|
||||
param_dict = load_checkpoint(args_opt.checkpoint)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32)
|
||||
export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
|
||||
export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR")
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 &
|
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
PATH_CHECKPOINT=$3
|
||||
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 &
|
|
@ -0,0 +1,19 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 &
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
|
||||
config_gpu = edict({
|
||||
'random_seed': 1,
|
||||
'rank': 0,
|
||||
'group_size': 1,
|
||||
'work_nums': 8,
|
||||
'decay_method': 'cosine',
|
||||
"loss_scale": 1,
|
||||
'batch_size': 128,
|
||||
'epoch_size': 250,
|
||||
'num_classes': 1000,
|
||||
'smooth_factor': 0.1,
|
||||
'aux_factor': 0.2,
|
||||
'lr_init': 0.00004,
|
||||
'lr_max': 0.4,
|
||||
'lr_end': 0.000004,
|
||||
'warmup_epochs': 1,
|
||||
'weight_decay': 0.00004,
|
||||
'momentum': 0.9,
|
||||
'opt_eps': 1.0,
|
||||
'keep_checkpoint_max': 100,
|
||||
'ckpt_path': './checkpoint/',
|
||||
'is_save_on_master': 0
|
||||
})
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
from src.config import config_gpu as cfg
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
rank (int): The shard ID within num_shards (default=None).
|
||||
group_size (int): Number of shards that the dataset should be divided into (default=None).
|
||||
repeat_num(int): the repeat times of dataset. Default: 1.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if group_size == 1:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
|
||||
else:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
|
||||
num_shards=group_size, shard_id=rank)
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
C.RandomCropDecodeResize(299, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize(299),
|
||||
C.CenterCrop(299)
|
||||
]
|
||||
trans += [
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=cfg.work_nums)
|
||||
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=cfg.work_nums)
|
||||
# apply batch operations
|
||||
ds = ds.batch(cfg.batch_size, drop_remainder=True)
|
||||
# apply dataset repeat operation
|
||||
ds = ds.repeat(repeat_num)
|
||||
return ds
|
|
@ -0,0 +1,257 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Inception-v3 model definition"""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import XavierUniform
|
||||
|
||||
|
||||
class BasicConv2d(nn.Cell):
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, pad_mode='same', padding=0):
|
||||
super(BasicConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
|
||||
pad_mode=pad_mode, padding=padding, weight_init=XavierUniform(), has_bias=True)
|
||||
self.bn = nn.BatchNorm2d(out_channel, eps=0.001, momentum=0.9997)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Inception_A(nn.Cell):
|
||||
def __init__(self, in_channels, pool_features):
|
||||
super(Inception_A, self).__init__()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.branch0 = BasicConv2d(in_channels, 64, kernel_size=1)
|
||||
self.branch1 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, 48, kernel_size=1),
|
||||
BasicConv2d(48, 64, kernel_size=5)
|
||||
])
|
||||
self.branch2 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, 64, kernel_size=1),
|
||||
BasicConv2d(64, 96, kernel_size=3),
|
||||
BasicConv2d(96, 96, kernel_size=3)
|
||||
|
||||
])
|
||||
self.branch_pool = nn.SequentialCell([
|
||||
nn.AvgPool2d(kernel_size=3, pad_mode='same'),
|
||||
BasicConv2d(in_channels, pool_features, kernel_size=1)
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
branch_pool = self.branch_pool(x)
|
||||
out = self.concat((x0, x1, x2, branch_pool))
|
||||
return out
|
||||
|
||||
|
||||
class Inception_B(nn.Cell):
|
||||
def __init__(self, in_channels):
|
||||
super(Inception_B, self).__init__()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.branch0 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2, pad_mode='valid')
|
||||
self.branch1 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, 64, kernel_size=1),
|
||||
BasicConv2d(64, 96, kernel_size=3),
|
||||
BasicConv2d(96, 96, kernel_size=3, stride=2, pad_mode='valid')
|
||||
|
||||
])
|
||||
self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
|
||||
|
||||
def construct(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
branch_pool = self.branch_pool(x)
|
||||
out = self.concat((x0, x1, branch_pool))
|
||||
return out
|
||||
|
||||
|
||||
class Inception_C(nn.Cell):
|
||||
def __init__(self, in_channels, channels_7x7):
|
||||
super(Inception_C, self).__init__()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.branch0 = BasicConv2d(in_channels, 192, kernel_size=1)
|
||||
self.branch1 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, channels_7x7, kernel_size=1),
|
||||
BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)),
|
||||
BasicConv2d(channels_7x7, 192, kernel_size=(7, 1))
|
||||
])
|
||||
self.branch2 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, channels_7x7, kernel_size=1),
|
||||
BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)),
|
||||
BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)),
|
||||
BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)),
|
||||
BasicConv2d(channels_7x7, 192, kernel_size=(1, 7))
|
||||
])
|
||||
self.branch_pool = nn.SequentialCell([
|
||||
nn.AvgPool2d(kernel_size=3, pad_mode='same'),
|
||||
BasicConv2d(in_channels, 192, kernel_size=1)
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
branch_pool = self.branch_pool(x)
|
||||
out = self.concat((x0, x1, x2, branch_pool))
|
||||
return out
|
||||
|
||||
|
||||
class Inception_D(nn.Cell):
|
||||
def __init__(self, in_channels):
|
||||
super(Inception_D, self).__init__()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.branch0 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, 192, kernel_size=1),
|
||||
BasicConv2d(192, 320, kernel_size=3, stride=2, pad_mode='valid')
|
||||
])
|
||||
self.branch1 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, 192, kernel_size=1),
|
||||
BasicConv2d(192, 192, kernel_size=(1, 7)), # check
|
||||
BasicConv2d(192, 192, kernel_size=(7, 1)),
|
||||
BasicConv2d(192, 192, kernel_size=3, stride=2, pad_mode='valid')
|
||||
])
|
||||
self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
|
||||
|
||||
def construct(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
branch_pool = self.branch_pool(x)
|
||||
out = self.concat((x0, x1, branch_pool))
|
||||
return out
|
||||
|
||||
|
||||
class Inception_E(nn.Cell):
|
||||
def __init__(self, in_channels):
|
||||
super(Inception_E, self).__init__()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.branch0 = BasicConv2d(in_channels, 320, kernel_size=1)
|
||||
self.branch1 = BasicConv2d(in_channels, 384, kernel_size=1)
|
||||
self.branch1_a = BasicConv2d(384, 384, kernel_size=(1, 3))
|
||||
self.branch1_b = BasicConv2d(384, 384, kernel_size=(3, 1))
|
||||
self.branch2 = nn.SequentialCell([
|
||||
BasicConv2d(in_channels, 448, kernel_size=1),
|
||||
BasicConv2d(448, 384, kernel_size=3)
|
||||
])
|
||||
self.branch2_a = BasicConv2d(384, 384, kernel_size=(1, 3))
|
||||
self.branch2_b = BasicConv2d(384, 384, kernel_size=(3, 1))
|
||||
self.branch_pool = nn.SequentialCell([
|
||||
nn.AvgPool2d(kernel_size=3, pad_mode='same'),
|
||||
BasicConv2d(in_channels, 192, kernel_size=1)
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x1 = self.concat((self.branch1_a(x1), self.branch1_b(x1)))
|
||||
x2 = self.branch2(x)
|
||||
x2 = self.concat((self.branch2_a(x2), self.branch2_b(x2)))
|
||||
branch_pool = self.branch_pool(x)
|
||||
out = self.concat((x0, x1, x2, branch_pool))
|
||||
return out
|
||||
|
||||
|
||||
class Logits(nn.Cell):
|
||||
def __init__(self, num_classes=10, dropout_keep_prob=0.8):
|
||||
super(Logits, self).__init__()
|
||||
self.avg_pool = nn.AvgPool2d(8, pad_mode='valid')
|
||||
self.dropout = nn.Dropout(keep_prob=dropout_keep_prob)
|
||||
self.flatten = P.Flatten()
|
||||
self.fc = nn.Dense(2048, num_classes)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.avg_pool(x)
|
||||
x = self.dropout(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class AuxLogits(nn.Cell):
|
||||
def __init__(self, in_channels, num_classes=10):
|
||||
super(AuxLogits, self).__init__()
|
||||
self.avg_pool = nn.AvgPool2d(5, stride=3, pad_mode='valid')
|
||||
self.conv2d_0 = nn.Conv2d(in_channels, 128, kernel_size=1)
|
||||
self.conv2d_1 = nn.Conv2d(128, 768, kernel_size=5, pad_mode='valid')
|
||||
self.flatten = P.Flatten()
|
||||
self.fc = nn.Dense(in_channels, num_classes)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.avg_pool(x)
|
||||
x = self.conv2d_0(x)
|
||||
x = self.conv2d_1(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class InceptionV3(nn.Cell):
|
||||
def __init__(self, num_classes=10, is_training=True):
|
||||
super(InceptionV3, self).__init__()
|
||||
self.is_training = is_training
|
||||
self.Conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2, pad_mode='valid')
|
||||
self.Conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1, pad_mode='valid')
|
||||
self.Conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1)
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
|
||||
self.Conv2d_3b = BasicConv2d(64, 80, kernel_size=1)
|
||||
self.Conv2d_4a = BasicConv2d(80, 192, kernel_size=3, pad_mode='valid')
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
|
||||
self.Mixed_5b = Inception_A(192, pool_features=32)
|
||||
self.Mixed_5c = Inception_A(256, pool_features=64)
|
||||
self.Mixed_5d = Inception_A(288, pool_features=64)
|
||||
self.Mixed_6a = Inception_B(288)
|
||||
self.Mixed_6b = Inception_C(768, channels_7x7=128)
|
||||
self.Mixed_6c = Inception_C(768, channels_7x7=160)
|
||||
self.Mixed_6d = Inception_C(768, channels_7x7=160)
|
||||
self.Mixed_6e = Inception_C(768, channels_7x7=192)
|
||||
self.Mixed_7a = Inception_D(768)
|
||||
self.Mixed_7b = Inception_E(1280)
|
||||
self.Mixed_7c = Inception_E(2048)
|
||||
if is_training:
|
||||
self.aux_logits = AuxLogits(768, num_classes)
|
||||
self.logits = Logits(num_classes, dropout_keep_prob=0.5)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.Conv2d_1a(x)
|
||||
x = self.Conv2d_2a(x)
|
||||
x = self.Conv2d_2b(x)
|
||||
x = self.maxpool1(x)
|
||||
x = self.Conv2d_3b(x)
|
||||
x = self.Conv2d_4a(x)
|
||||
x = self.maxpool2(x)
|
||||
x = self.Mixed_5b(x)
|
||||
x = self.Mixed_5c(x)
|
||||
x = self.Mixed_5d(x)
|
||||
x = self.Mixed_6a(x)
|
||||
x = self.Mixed_6b(x)
|
||||
x = self.Mixed_6c(x)
|
||||
x = self.Mixed_6d(x)
|
||||
x = self.Mixed_6e(x)
|
||||
if self.is_training:
|
||||
aux_logits = self.aux_logits(x)
|
||||
else:
|
||||
aux_logits = None
|
||||
x = self.Mixed_7a(x)
|
||||
x = self.Mixed_7b(x)
|
||||
x = self.Mixed_7c(x)
|
||||
logits = self.logits(x)
|
||||
if self.is_training:
|
||||
return logits, aux_logits
|
||||
return logits
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network."""
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
|
||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
|
||||
def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
|
||||
super(CrossEntropy, self).__init__()
|
||||
self.factor = factor
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logits, label):
|
||||
logit, aux = logits
|
||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss_logit = self.ce(logit, one_hot_label)
|
||||
loss_logit = self.mean(loss_logit, 0)
|
||||
one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
|
||||
loss_aux = self.ce(aux, one_hot_label_aux)
|
||||
loss_aux = self.mean(loss_aux, 0)
|
||||
return loss_logit + self.factor*loss_aux
|
||||
|
||||
|
||||
class CrossEntropy_Val(_Loss):
|
||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
|
||||
def __init__(self, smooth_factor=0, num_classes=1000):
|
||||
super(CrossEntropy_Val, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logits, label):
|
||||
one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
|
||||
loss_logit = self.ce(logits, one_hot_label)
|
||||
loss_logit = self.mean(loss_logit, 0)
|
||||
return loss_logit
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
if lr_decay_mode == 'steps':
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
if i < decay_epoch_index[0]:
|
||||
lr = lr_max
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr = lr_max * 0.1
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr = lr_max * 0.01
|
||||
else:
|
||||
lr = lr_max * 0.001
|
||||
lr_each_step.append(lr)
|
||||
elif lr_decay_mode == 'steps_decay':
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
decay_nums = math.floor((float(i-warmup_steps)/steps_per_epoch) / 2)
|
||||
decay_rate = pow(0.94, decay_nums)
|
||||
lr = float(lr_max)*decay_rate
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
elif lr_decay_mode == 'cosine':
|
||||
decay_steps = total_steps - warmup_steps
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
lr = float(lr_init) + lr_inc * (i + 1)
|
||||
else:
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps))
|
||||
lr = (lr_max-lr_end)*cosine_decay + lr_end
|
||||
lr_each_step.append(lr)
|
||||
else:
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
|
||||
lr_each_step.append(lr)
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
return learning_rate
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.nn.optim.rmsprop import RMSProp
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import dataset as de
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.inception_v3 import InceptionV3
|
||||
from src.lr_generator import get_lr
|
||||
from src.loss import CrossEntropy
|
||||
|
||||
random.seed(cfg.random_seed)
|
||||
np.random.seed(cfg.random_seed)
|
||||
de.config.set_seed(cfg.random_seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification training')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
||||
parser.add_argument('--is_distributed', action='store_true', default=False,
|
||||
help='distributed training')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
if args_opt.platform == "Ascend":
|
||||
init()
|
||||
else:
|
||||
init("nccl")
|
||||
cfg.rank = get_rank()
|
||||
cfg.group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
|
||||
parameter_broadcast=True, mirror_mean=True)
|
||||
else:
|
||||
cfg.rank = 0
|
||||
cfg.group_size = 1
|
||||
|
||||
# dataloader
|
||||
dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size)
|
||||
batches_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
# network
|
||||
net = InceptionV3(num_classes=cfg.num_classes)
|
||||
|
||||
# loss
|
||||
loss = CrossEntropy(smooth_factor=cfg.smooth_factor, num_classes=cfg.num_classes, factor=cfg.aux_factor)
|
||||
|
||||
# learning rate schedule
|
||||
lr = get_lr(lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
|
||||
total_epochs=cfg.epoch_size, steps_per_epoch=batches_per_epoch, lr_decay_mode=cfg.decay_method)
|
||||
lr = Tensor(lr)
|
||||
|
||||
# optimizer
|
||||
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
|
||||
no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
|
||||
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
optimizer = RMSProp(group_params, lr, decay=0.9, weight_decay=cfg.weight_decay,
|
||||
momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top1-Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': nn.Top5CategoricalAccuracy()}
|
||||
|
||||
if args_opt.resume:
|
||||
ckpt = load_checkpoint(args_opt.resume)
|
||||
load_param_into_net(net, ckpt)
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
loss_cb = LossMonitor(per_print_times=batches_per_epoch)
|
||||
time_cb = TimeMonitor(data_size=batches_per_epoch)
|
||||
callbacks = [loss_cb, time_cb]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv3-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
|
||||
if args_opt.is_distributed & cfg.is_save_on_master:
|
||||
if cfg.rank == 0:
|
||||
callbacks.append(ckpoint_cb)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
|
||||
else:
|
||||
callbacks.append(ckpoint_cb)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
|
||||
print("train success")
|
Loading…
Reference in New Issue