forked from mindspore-Ecosystem/mindspore
!11029 FCN8s
From: @zhu_wenyong Reviewed-by: @linqingke,@oacjiewen Signed-off-by: @linqingke
This commit is contained in:
commit
d4ef0452a6
|
@ -0,0 +1,296 @@
|
|||
# Contents
|
||||
|
||||
- [FCN 介绍](#FCN-介绍)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速开始](#快速开始)
|
||||
- [脚本介绍](#脚本介绍)
|
||||
- [脚本以及简单代码](#脚本以及简单代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练步骤](#训练步骤)
|
||||
- [训练](#训练)
|
||||
- [评估步骤](#评估步骤)
|
||||
- [评估](#评估)
|
||||
- [模型介绍](#模型介绍)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [如何使用](#如何使用)
|
||||
- [教程](#教程)
|
||||
- [随机事件介绍](#随机事件介绍)
|
||||
- [ModelZoo 主页](#ModelZoo-主页)
|
||||
|
||||
# [FCN 介绍](#contents)
|
||||
|
||||
FCN主要用用于图像分割领域,是一种端到端的分割方法。FCN丢弃了全连接层,使得其能够处理任意大小的图像,且减少了模型的参数量,提高了模型的分割速度。FCN在编码部分使用了VGG的结构,在解码部分中使用反卷积/上采样操作恢复图像的分辨率。FCN-8s最后使用8倍的反卷积/上采样操作将输出分割图恢复到与输入图像相同大小。
|
||||
|
||||
[Paper]: Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
|
||||
|
||||
# [模型架构](#contents)
|
||||
|
||||
FCN-8s使用丢弃全连接操作的VGG16作为编码部分,并分别融合VGG16中第3,4,5个池化层特征,最后使用stride=8的反卷积获得分割图像。
|
||||
|
||||
# [数据集](#contents)
|
||||
|
||||
Dataset used:
|
||||
|
||||
[PASCAL VOC 2012](<http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html>)
|
||||
|
||||
[SBD](<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz>)
|
||||
|
||||
# [环境要求](#contents)
|
||||
|
||||
- 硬件(Ascend/GPU)
|
||||
- 需要准备具有Ascend或GPU处理能力的硬件环境. 如需使用Ascend,可以发送 [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) 到ascend@huawei.com。一旦批准,你就可以使用此资源
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需获取更多信息,请查看如下链接:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [快速开始](#contents)
|
||||
|
||||
在通过官方网站安装MindSpore之后,你可以通过如下步骤开始训练以及评估:
|
||||
|
||||
- runing on Ascend with default paramaters
|
||||
|
||||
```python
|
||||
# run training example
|
||||
python train.py --device_id device_id
|
||||
|
||||
# run evaluation example with default paramaters
|
||||
python eval.py --device_id device_id
|
||||
```
|
||||
|
||||
# [脚本介绍](#contents)
|
||||
|
||||
## [脚本以及简单代码](#contents)
|
||||
|
||||
```python
|
||||
├── model_zoo
|
||||
├── README.md // descriptions about all the models
|
||||
├── FCN8s
|
||||
├── README.md // descriptions about FCN
|
||||
├── scripts
|
||||
├── run_train.sh
|
||||
├── run_eval.sh
|
||||
├── build_data.sh
|
||||
├── src
|
||||
│ ├──data
|
||||
│ ├──build_seg_data.py // creating dataset
|
||||
│ ├──dataset.py // loading dataset
|
||||
│ ├──nets
|
||||
│ ├──FCN8s.py // FCN-8s architecture
|
||||
│ ├──loss
|
||||
│ ├──loss.py // loss function
|
||||
│ ├──utils
|
||||
│ ├──lr_scheduler.py // getting learning_rateFCN-8s
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
```
|
||||
|
||||
## [脚本参数](#contents)
|
||||
|
||||
训练以及评估的参数可以在config.py中设置
|
||||
|
||||
- config for FCN8s
|
||||
|
||||
```python
|
||||
# dataset
|
||||
'data_file': '/data/workspace/mindspore_dataset/FCN/FCN/dataset/MINDRECORED_NAME.mindrecord', # path and name of one mindrecord file
|
||||
'batch_size': 32,
|
||||
'crop_size': 512,
|
||||
'image_mean': [103.53, 116.28, 123.675],
|
||||
'image_std': [57.375, 57.120, 58.395],
|
||||
'min_scale': 0.5,
|
||||
'max_scale': 2.0,
|
||||
'ignore_label': 255,
|
||||
'num_classes': 21,
|
||||
|
||||
# optimizer
|
||||
'train_epochs': 500,
|
||||
'base_lr': 0.015,
|
||||
'loss_scale': 1024.0,
|
||||
|
||||
# model
|
||||
'model': 'FCN8s',
|
||||
'ckpt_vgg16': '/data/workspace/mindspore_dataset/FCN/FCN/model/0-150_5004.ckpt',
|
||||
'ckpt_pre_trained': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt',
|
||||
|
||||
# train
|
||||
'save_steps': 330,
|
||||
'keep_checkpoint_max': 500,
|
||||
'train_dir': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/',
|
||||
```
|
||||
|
||||
如需获取更多信息,请查看`config.py`.
|
||||
|
||||
## [生成数据步骤](#contents)
|
||||
|
||||
### 训练数据
|
||||
|
||||
- build mindrecord training data
|
||||
|
||||
```python
|
||||
sh build_data.sh
|
||||
or
|
||||
python src/data/build_seg_data.py --data_root=/home/sun/data/Mindspore/benchmark_RELEASE/dataset \
|
||||
--data_lst=/home/sun/data/Mindspore/benchmark_RELEASE/dataset/trainaug.txt \
|
||||
--dst_path=dataset/MINDRECORED_NAME.mindrecord \
|
||||
--num_shards=1 \
|
||||
--shuffle=True
|
||||
data_root: 训练数据集的总目录包含两个子目录img和cls_png,img目录下存放训练图像,cls_png目录下存放标签mask图像,
|
||||
data_lst: 存放训练样本的名称列表文档,每行一个样本。
|
||||
dst_path: 生成mindrecord数据的目标位置
|
||||
```
|
||||
|
||||
## [训练步骤](#contents)
|
||||
|
||||
### 训练
|
||||
|
||||
- running on Ascend with default parameters
|
||||
|
||||
```python
|
||||
python train.py --device_id device_id
|
||||
```
|
||||
|
||||
训练时,训练过程中的epch和step以及此时的loss和精确度会呈现在终端上:
|
||||
|
||||
```python
|
||||
epoch: * step: **, loss is ****
|
||||
...
|
||||
```
|
||||
|
||||
此模型的checkpoint会在默认路径下存储
|
||||
|
||||
## [评估步骤](#contents)
|
||||
|
||||
### 评估
|
||||
|
||||
- 在Ascend上使用PASCAL VOC 2012 验证集进行评估
|
||||
|
||||
在使用命令运行前,请检查用于评估的checkpoint的路径。请设置路径为到checkpoint的绝对路径,如 "/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt"。
|
||||
|
||||
```python
|
||||
python eval.py
|
||||
```
|
||||
|
||||
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以如下方式呈现:
|
||||
|
||||
```python
|
||||
mean IoU 0.6467
|
||||
```
|
||||
|
||||
# [模型介绍](#contents)
|
||||
|
||||
## [性能](#contents)
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### FCN8s on PASCAL VOC 2012
|
||||
|
||||
| Parameters | Ascend
|
||||
| -------------------------- | -----------------------------------------------------------
|
||||
| Model Version | FCN-8s
|
||||
| Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G
|
||||
| uploaded Date | 12/30/2020 (month/day/year)
|
||||
| MindSpore Version | 1.1.0-alpha
|
||||
| Dataset | PASCAL VOC 2012 and SBD
|
||||
| Training Parameters | epoch=500, steps=330, batch_size = 32, lr=0.015
|
||||
| Optimizer | Momentum
|
||||
| Loss Function | Softmax Cross Entropy
|
||||
| outputs | probability
|
||||
| Loss | 0.038
|
||||
| Speed | 1pc: 564.652 ms/step;
|
||||
| Scripts | [FCN script](https://gitee.com/mindspore/mindspore/tree/r1.0/model_zoo/official/cv/FCN)
|
||||
|
||||
### Inference Performance
|
||||
|
||||
#### FCN8s on PASCAL VOC
|
||||
|
||||
| Parameters | Ascend
|
||||
| ------------------- | ---------------------------
|
||||
| Model Version | FCN-8s
|
||||
| Resource | Ascend 910
|
||||
| Uploaded Date | 10/29/2020 (month/day/year)
|
||||
| MindSpore Version | 1.1.0-alpha
|
||||
| Dataset | PASCAL VOC 2012
|
||||
| batch_size | 16
|
||||
| outputs | probability
|
||||
| mean IoU | 64.67
|
||||
|
||||
## [如何使用](#contents)
|
||||
|
||||
### 教程
|
||||
|
||||
如果你需要在不同硬件平台(如GPU,Ascend 910 或者 Ascend 310)使用训练好的模型,你可以参考这个 [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html)。以下是一个简单例子的步骤介绍:
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
```
|
||||
# Set context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
|
||||
context.set_auto_parallel_context(device_num=device_num,parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
init()
|
||||
|
||||
# Load dataset
|
||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||
image_std=cfg.image_std,
|
||||
data_file=cfg.data_file,
|
||||
batch_size=cfg.batch_size,
|
||||
crop_size=cfg.crop_size,
|
||||
max_scale=cfg.max_scale,
|
||||
min_scale=cfg.min_scale,
|
||||
ignore_label=cfg.ignore_label,
|
||||
num_classes=cfg.num_classes,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=args.rank,
|
||||
shard_num=args.group_size)
|
||||
dataset = dataset.get_dataset(repeat=1)
|
||||
|
||||
# Define model
|
||||
net = FCN8s(n_class=cfg.num_classes)
|
||||
loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
total_train_steps = iters_per_epoch * cfg.train_epochs
|
||||
|
||||
lr_scheduler = CosineAnnealingLR(cfg.base_lr,
|
||||
cfg.train_epochs,
|
||||
iters_per_epoch,
|
||||
cfg.train_epochs,
|
||||
warmup_epochs=0,
|
||||
eta_min=0)
|
||||
lr = Tensor(lr_scheduler.get_lr())
|
||||
|
||||
# loss scale
|
||||
manager_loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=cfg.loss_scale)
|
||||
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
# callback for saving ckpts
|
||||
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
||||
loss_cb = LossMonitor()
|
||||
cbs = [time_cb, loss_cb]
|
||||
|
||||
if args.rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.train_dir, config=config_ck)
|
||||
cbs.append(ckpoint_cb)
|
||||
|
||||
model.train(cfg.train_epochs, dataset, callbacks=cbs)
|
||||
|
||||
# [随机事件介绍](#contents)
|
||||
|
||||
我们在train.py中设置了随机种子
|
||||
|
||||
# [ModelZoo 主页](#contents)
|
||||
|
||||
请查看官方网站 [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval FCN8s."""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.nets.FCN8s import FCN8s
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindspore FCN8s eval')
|
||||
|
||||
# val data
|
||||
parser.add_argument('--data_root', type=str, default='../VOCdevkit/VOC2012/', help='root path of val data')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
|
||||
parser.add_argument('--data_lst', type=str, default='../VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt',
|
||||
help='list of val data')
|
||||
parser.add_argument('--crop_size', type=int, default=512, help='crop size')
|
||||
parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
|
||||
parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
|
||||
parser.add_argument('--scales', type=float, default=[1.0], action='append', help='scales of evaluation')
|
||||
parser.add_argument('--flip', type=bool, default=False, help='perform left-right flip')
|
||||
parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
|
||||
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
|
||||
|
||||
# model
|
||||
parser.add_argument('--model', type=str, default='FCN8s', help='select model')
|
||||
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
||||
parser.add_argument('--ckpt_path', type=str, default='model_new/FCN8s-500_82.ckpt', help='model to evaluate')
|
||||
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
return args
|
||||
|
||||
|
||||
def cal_hist(a, b, n):
|
||||
k = (a >= 0) & (a < n)
|
||||
return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
|
||||
|
||||
|
||||
def resize_long(img, long_size=513):
|
||||
h, w, _ = img.shape
|
||||
if h > w:
|
||||
new_h = long_size
|
||||
new_w = int(1.0 * long_size * w / h)
|
||||
else:
|
||||
new_w = long_size
|
||||
new_h = int(1.0 * long_size * h / w)
|
||||
imo = cv2.resize(img, (new_w, new_h))
|
||||
return imo
|
||||
|
||||
|
||||
class BuildEvalNetwork(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(BuildEvalNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.softmax = nn.Softmax(axis=1)
|
||||
|
||||
def construct(self, input_data):
|
||||
output = self.network(input_data)
|
||||
output = self.softmax(output)
|
||||
return output
|
||||
|
||||
|
||||
def pre_process(args, img_, crop_size=512):
|
||||
# resize
|
||||
img_ = resize_long(img_, crop_size)
|
||||
resize_h, resize_w, _ = img_.shape
|
||||
|
||||
# mean, std
|
||||
image_mean = np.array(args.image_mean)
|
||||
image_std = np.array(args.image_std)
|
||||
img_ = (img_ - image_mean) / image_std
|
||||
|
||||
# pad to crop_size
|
||||
pad_h = crop_size - img_.shape[0]
|
||||
pad_w = crop_size - img_.shape[1]
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
|
||||
# hwc to chw
|
||||
img_ = img_.transpose((2, 0, 1))
|
||||
return img_, resize_h, resize_w
|
||||
|
||||
|
||||
def eval_batch(args, eval_net, img_lst, crop_size=512, flip=True):
|
||||
result_lst = []
|
||||
batch_size = len(img_lst)
|
||||
batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
|
||||
resize_hw = []
|
||||
for l in range(batch_size):
|
||||
img_ = img_lst[l]
|
||||
img_, resize_h, resize_w = pre_process(args, img_, crop_size)
|
||||
batch_img[l] = img_
|
||||
resize_hw.append([resize_h, resize_w])
|
||||
|
||||
batch_img = np.ascontiguousarray(batch_img)
|
||||
net_out = eval_net(Tensor(batch_img, mstype.float32))
|
||||
net_out = net_out.asnumpy()
|
||||
|
||||
if flip:
|
||||
batch_img = batch_img[:, :, :, ::-1]
|
||||
net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
|
||||
net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
|
||||
|
||||
for bs in range(batch_size):
|
||||
probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
|
||||
ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
|
||||
probs_ = cv2.resize(probs_, (ori_w, ori_h))
|
||||
result_lst.append(probs_)
|
||||
|
||||
return result_lst
|
||||
|
||||
|
||||
def eval_batch_scales(args, eval_net, img_lst, scales,
|
||||
base_crop_size=512, flip=True):
|
||||
sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
|
||||
probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
|
||||
print(sizes_)
|
||||
for crop_size_ in sizes_[1:]:
|
||||
probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
|
||||
for pl, _ in enumerate(probs_lst):
|
||||
probs_lst[pl] += probs_lst_tmp[pl]
|
||||
|
||||
result_msk = []
|
||||
for i in probs_lst:
|
||||
result_msk.append(i.argmax(axis=2))
|
||||
return result_msk
|
||||
|
||||
|
||||
def net_eval():
|
||||
args = parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id,
|
||||
save_graphs=False)
|
||||
|
||||
# data list
|
||||
with open(args.data_lst) as f:
|
||||
img_lst = f.readlines()
|
||||
|
||||
net = FCN8s(n_class=args.num_classes)
|
||||
|
||||
# load model
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# evaluate
|
||||
hist = np.zeros((args.num_classes, args.num_classes))
|
||||
batch_img_lst = []
|
||||
batch_msk_lst = []
|
||||
bi = 0
|
||||
image_num = 0
|
||||
for i, line in enumerate(img_lst):
|
||||
|
||||
img_name = line.strip('\n')
|
||||
data_root = args.data_root
|
||||
img_path = data_root + '/JPEGImages/' + str(img_name) + '.jpg'
|
||||
msk_path = data_root + '/SegmentationClass/' + str(img_name) + '.png'
|
||||
|
||||
img_ = np.array(Image.open(img_path), dtype=np.uint8)
|
||||
msk_ = np.array(Image.open(msk_path), dtype=np.uint8)
|
||||
|
||||
batch_img_lst.append(img_)
|
||||
batch_msk_lst.append(msk_)
|
||||
bi += 1
|
||||
if bi == args.batch_size:
|
||||
batch_res = eval_batch_scales(args, net, batch_img_lst, scales=args.scales,
|
||||
base_crop_size=args.crop_size, flip=args.flip)
|
||||
for mi in range(args.batch_size):
|
||||
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
|
||||
|
||||
bi = 0
|
||||
batch_img_lst = []
|
||||
batch_msk_lst = []
|
||||
print('processed {} images'.format(i+1))
|
||||
image_num = i
|
||||
|
||||
if bi > 0:
|
||||
batch_res = eval_batch_scales(args, net, batch_img_lst, scales=args.scales,
|
||||
base_crop_size=args.crop_size, flip=args.flip)
|
||||
for mi in range(bi):
|
||||
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
|
||||
print('processed {} images'.format(image_num + 1))
|
||||
|
||||
print(hist)
|
||||
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
|
||||
print('per-class IoU', iu)
|
||||
print('mean IoU', np.nanmean(iu))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net_eval()
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=0
|
||||
python src/data/build_seg_data.py --data_root=/home/sun/data/Mindspore/benchmark_RELEASE/dataset \
|
||||
--data_lst=/home/sun/data/Mindspore/benchmark_RELEASE/dataset/trainaug.txt \
|
||||
--dst_path=dataset/MINDRECORED_NAME.mindrecord \
|
||||
--num_shards=1 \
|
||||
--shuffle=True
|
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distribute_eval.sh DEVICE_NUM RANK_TABLE_FILE DATASET CKPT_PATH"
|
||||
echo "for example: sh run_eval.sh [RANK_TABLE_FILE] /path/to/dataset /path/to/ckpt device_id"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
|
||||
export DATA_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
DEVICE_ID=$3
|
||||
|
||||
rm -rf eval
|
||||
mkdir ./eval
|
||||
cp ./*.py ./eval
|
||||
cp -r ./src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start testing"
|
||||
env > env.log
|
||||
python eval.py \
|
||||
--device_id=$DEVICE_ID \
|
||||
--data_path=$DATA_PATH \
|
||||
--ckpt_path=$CKPT_PATH #> log.txt 2>&1 &
|
||||
|
||||
cd ../
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [device_num][RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $2 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=$1
|
||||
export RANK_SIZE=$1
|
||||
RANK_TABLE_FILE=$(realpath $2)
|
||||
export RANK_TABLE_FILE
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<$1; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
python train.py --device_id=$i > log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
|
||||
FCN8s_VOC2012_cfg = edict({
|
||||
# dataset
|
||||
'data_file': '/data/workspace/mindspore_dataset/FCN/FCN/dataset/MINDRECORED_NAME.mindrecord',
|
||||
'batch_size': 32,
|
||||
'crop_size': 512,
|
||||
'image_mean': [103.53, 116.28, 123.675],
|
||||
'image_std': [57.375, 57.120, 58.395],
|
||||
'min_scale': 0.5,
|
||||
'max_scale': 2.0,
|
||||
'ignore_label': 255,
|
||||
'num_classes': 21,
|
||||
|
||||
# optimizer
|
||||
'train_epochs': 500,
|
||||
'base_lr': 0.015,
|
||||
'loss_scale': 1024.0,
|
||||
|
||||
# model
|
||||
'model': 'FCN8s',
|
||||
'ckpt_vgg16': '/data/workspace/mindspore_dataset/FCN/FCN/model/0-150_5004.ckpt',
|
||||
'ckpt_pre_trained': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt',
|
||||
|
||||
# train
|
||||
'save_steps': 330,
|
||||
'keep_checkpoint_max': 500,
|
||||
'train_dir': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/',
|
||||
})
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindrecord')
|
||||
|
||||
parser.add_argument('--data_root', type=str, default='', help='root path of data')
|
||||
parser.add_argument('--data_lst', type=str, default='', help='list of data')
|
||||
parser.add_argument('--dst_path', type=str, default='', help='save path of mindrecords')
|
||||
parser.add_argument('--num_shards', type=int, default=8, help='number of shards')
|
||||
parser.add_argument('--shuffle', type=bool, default=True, help='shuffle or not')
|
||||
|
||||
parser_args, _ = parser.parse_known_args()
|
||||
return parser_args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
datas = []
|
||||
with open(args.data_lst) as f:
|
||||
lines = f.readlines()
|
||||
if args.shuffle:
|
||||
np.random.shuffle(lines)
|
||||
|
||||
dst_dir = '/'.join(args.dst_path.split('/')[:-1])
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
|
||||
print('number of samples:', len(lines))
|
||||
writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
|
||||
writer.add_schema(seg_schema, "seg_schema")
|
||||
cnt = 0
|
||||
|
||||
for l in lines:
|
||||
img_name = l.strip('\n')
|
||||
|
||||
img_path = 'img/' + str(img_name) + '.jpg'
|
||||
label_path = 'cls_png/' + str(img_name) + '.png'
|
||||
|
||||
sample_ = {"file_name": img_path.split('/')[-1]}
|
||||
|
||||
with open(os.path.join(args.data_root, img_path), 'rb') as f:
|
||||
sample_['data'] = f.read()
|
||||
with open(os.path.join(args.data_root, label_path), 'rb') as f:
|
||||
sample_['label'] = f.read()
|
||||
datas.append(sample_)
|
||||
cnt += 1
|
||||
if cnt % 1000 == 0:
|
||||
writer.write_raw_data(datas)
|
||||
print('number of samples written:', cnt)
|
||||
datas = []
|
||||
|
||||
if datas:
|
||||
writer.write_raw_data(datas)
|
||||
writer.commit()
|
||||
print('number of samples written:', cnt)
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import mindspore.dataset as de
|
||||
cv2.setNumThreads(0)
|
||||
|
||||
|
||||
class SegDataset:
|
||||
def __init__(self,
|
||||
image_mean,
|
||||
image_std,
|
||||
data_file='',
|
||||
batch_size=32,
|
||||
crop_size=512,
|
||||
max_scale=2.0,
|
||||
min_scale=0.5,
|
||||
ignore_label=255,
|
||||
num_classes=21,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=None,
|
||||
shard_num=None):
|
||||
|
||||
self.data_file = data_file
|
||||
self.batch_size = batch_size
|
||||
self.crop_size = crop_size
|
||||
self.image_mean = np.array(image_mean, dtype=np.float32)
|
||||
self.image_std = np.array(image_std, dtype=np.float32)
|
||||
self.max_scale = max_scale
|
||||
self.min_scale = min_scale
|
||||
self.ignore_label = ignore_label
|
||||
self.num_classes = num_classes
|
||||
self.num_readers = num_readers
|
||||
self.num_parallel_calls = num_parallel_calls
|
||||
self.shard_id = shard_id
|
||||
self.shard_num = shard_num
|
||||
assert max_scale > min_scale
|
||||
|
||||
def preprocess_(self, image, label):
|
||||
# bgr image
|
||||
image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
sc = np.random.uniform(self.min_scale, self.max_scale)
|
||||
new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
|
||||
image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
|
||||
label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
image_out = (image_out - self.image_mean) / self.image_std
|
||||
h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size)
|
||||
pad_h, pad_w = h_ - new_h, w_ - new_w
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
|
||||
offset_h = np.random.randint(0, h_ - self.crop_size + 1)
|
||||
offset_w = np.random.randint(0, w_ - self.crop_size + 1)
|
||||
image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
|
||||
label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
|
||||
|
||||
if np.random.uniform(0.0, 1.0) > 0.5:
|
||||
image_out = image_out[:, ::-1, :]
|
||||
label_out = label_out[:, ::-1]
|
||||
|
||||
image_out = image_out.transpose((2, 0, 1))
|
||||
image_out = image_out.copy()
|
||||
label_out = label_out.copy()
|
||||
return image_out, label_out
|
||||
|
||||
def get_dataset(self, repeat=1):
|
||||
data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"],
|
||||
shuffle=True, num_parallel_workers=self.num_readers,
|
||||
num_shards=self.shard_num, shard_id=self.shard_id)
|
||||
transforms_list = self.preprocess_
|
||||
data_set = data_set.map(operations=transforms_list, input_columns=["data", "label"],
|
||||
output_columns=["data", "label"],
|
||||
num_parallel_workers=self.num_parallel_calls)
|
||||
data_set = data_set.shuffle(buffer_size=self.batch_size * 10)
|
||||
data_set = data_set.batch(self.batch_size, drop_remainder=True)
|
||||
data_set = data_set.repeat(repeat)
|
||||
return data_set
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLoss(nn.Cell):
|
||||
def __init__(self, num_cls=21, ignore_label=255):
|
||||
super(SoftmaxCrossEntropyLoss, self).__init__()
|
||||
self.one_hot = P.OneHot(axis=-1)
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.num_cls = num_cls
|
||||
self.ignore_label = ignore_label
|
||||
self.mul = P.Mul()
|
||||
self.sum = P.ReduceSum(False)
|
||||
self.div = P.RealDiv()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, logits, labels):
|
||||
labels_int = self.cast(labels, mstype.int32)
|
||||
labels_int = self.reshape(labels_int, (-1,))
|
||||
logits_ = self.transpose(logits, (0, 2, 3, 1))
|
||||
logits_ = self.reshape(logits_, (-1, self.num_cls))
|
||||
weights = self.not_equal(labels_int, self.ignore_label)
|
||||
weights = self.cast(weights, mstype.float32)
|
||||
one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
|
||||
logits_ = self.cast(logits_, mstype.float32)
|
||||
loss = self.ce(logits_, one_hot_labels)
|
||||
loss = self.mul(weights, loss)
|
||||
loss = self.div(self.sum(loss), self.sum(weights))
|
||||
return loss
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class FCN8s(nn.Cell):
|
||||
def __init__(self, n_class):
|
||||
super().__init__()
|
||||
self.n_class = n_class
|
||||
self.conv1 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=3,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv2 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv3 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=128,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv4 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv5 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv6 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=4096,
|
||||
kernel_size=7,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.conv7 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=4096,
|
||||
out_channels=4096,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.score_fr = nn.Conv2d(in_channels=4096,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.score_pool4 = nn.Conv2d(in_channels=512,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.score_pool3 = nn.Conv2d(in_channels=256,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=16,
|
||||
stride=8,
|
||||
weight_init='xavier_uniform')
|
||||
self.shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.conv1(x)
|
||||
p1 = self.pool1(x1)
|
||||
x2 = self.conv2(p1)
|
||||
p2 = self.pool2(x2)
|
||||
x3 = self.conv3(p2)
|
||||
p3 = self.pool3(x3)
|
||||
x4 = self.conv4(p3)
|
||||
p4 = self.pool4(x4)
|
||||
x5 = self.conv5(p4)
|
||||
p5 = self.pool5(x5)
|
||||
|
||||
x6 = self.conv6(p5)
|
||||
x7 = self.conv7(x6)
|
||||
|
||||
sf = self.score_fr(x7)
|
||||
u2 = self.upscore2(sf)
|
||||
|
||||
s4 = self.score_pool4(p4)
|
||||
f4 = s4 + u2
|
||||
u4 = self.upscore_pool4(f4)
|
||||
|
||||
s3 = self.score_pool3(p3)
|
||||
f3 = s3 + u4
|
||||
out = self.upscore8(f3)
|
||||
|
||||
return out
|
|
@ -0,0 +1,656 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
learning rate scheduler
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["LambdaLR", "MultiplicativeLR", "StepLR", "MultiStepLR", "ExponentialLR",
|
||||
"CosineAnnealingLR", "CyclicLR", "CosineAnnealingWarmRestarts", "OneCycleLR"]
|
||||
|
||||
class _WarmUp():
|
||||
def __init__(self, warmup_init_lr):
|
||||
self.warmup_init_lr = warmup_init_lr
|
||||
|
||||
def get_lr(self):
|
||||
# Get learning rate during warmup
|
||||
raise NotImplementedError
|
||||
|
||||
class _LinearWarmUp(_WarmUp):
|
||||
"""
|
||||
linear warmup function
|
||||
"""
|
||||
def __init__(self, lr, warmup_epochs, steps_per_epoch, warmup_init_lr=0):
|
||||
self.base_lr = lr
|
||||
self.warmup_init_lr = warmup_init_lr
|
||||
self.warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
super(_LinearWarmUp, self).__init__(warmup_init_lr)
|
||||
|
||||
def get_warmup_steps(self):
|
||||
return self.warmup_steps
|
||||
|
||||
def get_lr(self, current_step):
|
||||
lr_inc = (float(self.base_lr) - float(self.warmup_init_lr)) / float(self.warmup_steps)
|
||||
lr = float(self.warmup_init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
class _ConstWarmUp(_WarmUp):
|
||||
|
||||
def get_lr(self):
|
||||
return self.warmup_init_lr
|
||||
|
||||
class _LRScheduler():
|
||||
|
||||
def __init__(self, lr, max_epoch, steps_per_epoch):
|
||||
self.base_lr = lr
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.total_steps = int(max_epoch * steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
# Compute learning rate using chainable form of the scheduler
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LambdaLR(_LRScheduler):
|
||||
"""Sets the learning rate to the initial lr times a given function.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
lr_lambda (function or list): A function which computes a multiplicative
|
||||
factor given an integer parameter epoch.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
Example:
|
||||
>>> # Assuming optimizer has two groups.
|
||||
>>> lambda1 = lambda epoch: epoch // 30
|
||||
>>> scheduler = LambdaLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
|
||||
def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.lr_lambda = lr_lambda
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(LambdaLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
lr = self.base_lr * self.lr_lambda(cur_ep)
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class MultiplicativeLR(_LRScheduler):
|
||||
"""Multiply the learning rate by the factor given
|
||||
in the specified function.
|
||||
|
||||
Args:
|
||||
lr_lambda (function or list): A function which computes a multiplicative
|
||||
factor given an integer parameter epoch,.
|
||||
|
||||
Example:
|
||||
>>> lmbda = lambda epoch: 0.95
|
||||
>>> scheduler = MultiplicativeLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.lr_lambda = lr_lambda
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(MultiplicativeLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
if i % self.steps_per_epoch == 0 and cur_ep > 0:
|
||||
current_lr = current_lr * self.lr_lambda(cur_ep)
|
||||
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class StepLR(_LRScheduler):
|
||||
"""Decays the learning rate by gamma every epoch_size epochs.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
epoch_size (int): Period of learning rate decay.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
Default: 0.1.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
Example:
|
||||
>>> # Assuming optimizer uses lr = 0.05 for all groups
|
||||
>>> # lr = 0.05 if epoch < 30
|
||||
>>> # lr = 0.005 if 30 <= epoch < 60
|
||||
>>> # lr = 0.0005 if 60 <= epoch < 90
|
||||
>>> # ...
|
||||
>>> scheduler = StepLR(lr=0.1, epoch_size=30, gamma=0.1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
|
||||
def __init__(self, lr, epoch_size, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.epoch_size = epoch_size
|
||||
self.gamma = gamma
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(StepLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
lr = self.base_lr * self.gamma**(cur_ep // self.epoch_size)
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
"""Decays the learning rate by gamma once the number of epoch reaches one
|
||||
of the milestones.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
milestones (list): List of epoch indices. Must be increasing.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
Default: 0.1.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
Example:
|
||||
>>> # Assuming optimizer uses lr = 0.05 for all groups
|
||||
>>> # lr = 0.05 if epoch < 30
|
||||
>>> # lr = 0.005 if 30 <= epoch < 80
|
||||
>>> # lr = 0.0005 if epoch >= 80
|
||||
>>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000,
|
||||
>>> max_epoch=90, warmup_epochs=0)
|
||||
>>> lr = scheduler.get_lr()
|
||||
"""
|
||||
|
||||
def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.milestones = Counter(milestones)
|
||||
self.gamma = gamma
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
|
||||
current_lr = current_lr * self.gamma
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
"""Decays the learning rate of each parameter group by gamma every epoch.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
gamma (float): Multiplicative factor of learning rate decay.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
"""
|
||||
|
||||
def __init__(self, lr, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
|
||||
self.gamma = gamma
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(ExponentialLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
if i % self.steps_per_epoch == 0 and i > 0:
|
||||
current_lr = current_lr * self.gamma
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
r"""Set the learning rate using a cosine annealing schedule, where
|
||||
:math:`\eta_{max}` is set to the initial lr and :math:`T_{cur}` is the
|
||||
number of epochs since the last restart in SGDR:
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
|
||||
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
|
||||
& T_{cur} \neq (2k+1)T_{max}; \\
|
||||
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
|
||||
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
|
||||
& T_{cur} = (2k+1)T_{max}.
|
||||
\end{aligned}
|
||||
|
||||
It has been proposed in
|
||||
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
|
||||
implements the cosine annealing part of SGDR, and not the restarts.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
T_max (int): Maximum number of iterations.
|
||||
eta_min (float): Minimum learning rate. Default: 0.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
|
||||
def __init__(self, lr, T_max, steps_per_epoch, max_epoch, warmup_epochs=0, eta_min=0):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(CosineAnnealingLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
current_lr = self.base_lr
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
cur_ep = i // self.steps_per_epoch
|
||||
if i % self.steps_per_epoch == 0 and i > 0:
|
||||
current_lr = self.eta_min + \
|
||||
(self.base_lr - self.eta_min) * (1. + math.cos(math.pi*cur_ep / self.T_max)) / 2
|
||||
|
||||
lr = current_lr
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CyclicLR(_LRScheduler):
|
||||
r"""Sets the learning rate according to cyclical learning rate policy (CLR).
|
||||
The policy cycles the learning rate between two boundaries with a constant
|
||||
frequency, as detailed in the paper `Cyclical Learning Rates for Training
|
||||
Neural Networks`_. The distance between the two boundaries can be scaled on
|
||||
a per-iteration or per-cycle basis.
|
||||
|
||||
Cyclical learning rate policy changes the learning rate after every batch.
|
||||
|
||||
This class has three built-in policies, as put forth in the paper:
|
||||
|
||||
* "triangular": A basic triangular cycle without amplitude scaling.
|
||||
* "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
|
||||
* "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
|
||||
at each cycle iteration.
|
||||
|
||||
This implementation was adapted from the github repo: `bckenstler/CLR`_
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate which is the
|
||||
lower boundary in the cycle.
|
||||
max_lr (float): Upper learning rate boundaries in the cycle.
|
||||
Functionally, it defines the cycle amplitude (max_lr - base_lr).
|
||||
The lr at any cycle is the sum of base_lr and some scaling
|
||||
of the amplitude; therefore max_lr may not actually be reached
|
||||
depending on scaling function.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
step_size_up (int): Number of training iterations in the
|
||||
increasing half of a cycle. Default: 2000
|
||||
step_size_down (int): Number of training iterations in the
|
||||
decreasing half of a cycle. If step_size_down is None,
|
||||
it is set to step_size_up. Default: None
|
||||
mode (str): One of {triangular, triangular2, exp_range}.
|
||||
Values correspond to policies detailed above.
|
||||
If scale_fn is not None, this argument is ignored.
|
||||
Default: 'triangular'
|
||||
gamma (float): Constant in 'exp_range' scaling function:
|
||||
gamma**(cycle iterations)
|
||||
Default: 1.0
|
||||
scale_fn (function): Custom scaling policy defined by a single
|
||||
argument lambda function, where
|
||||
0 <= scale_fn(x) <= 1 for all x >= 0.
|
||||
If specified, then 'mode' is ignored.
|
||||
Default: None
|
||||
scale_mode (str): {'cycle', 'iterations'}.
|
||||
Defines whether scale_fn is evaluated on
|
||||
cycle number or cycle iterations (training
|
||||
iterations since start of cycle).
|
||||
Default: 'cycle'
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
|
||||
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lr,
|
||||
max_lr,
|
||||
steps_per_epoch,
|
||||
max_epoch,
|
||||
step_size_up=2000,
|
||||
step_size_down=None,
|
||||
mode='triangular',
|
||||
gamma=1.,
|
||||
scale_fn=None,
|
||||
scale_mode='cycle',
|
||||
warmup_epochs=0):
|
||||
|
||||
self.max_lr = max_lr
|
||||
|
||||
step_size_up = float(step_size_up)
|
||||
step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
|
||||
self.total_size = step_size_up + step_size_down
|
||||
self.step_ratio = step_size_up / self.total_size
|
||||
|
||||
if mode not in ['triangular', 'triangular2', 'exp_range'] \
|
||||
and scale_fn is None:
|
||||
raise ValueError('mode is invalid and scale_fn is None')
|
||||
|
||||
self.mode = mode
|
||||
self.gamma = gamma
|
||||
|
||||
if scale_fn is None:
|
||||
if self.mode == 'triangular':
|
||||
self.scale_fn = self._triangular_scale_fn
|
||||
self.scale_mode = 'cycle'
|
||||
elif self.mode == 'triangular2':
|
||||
self.scale_fn = self._triangular2_scale_fn
|
||||
self.scale_mode = 'cycle'
|
||||
elif self.mode == 'exp_range':
|
||||
self.scale_fn = self._exp_range_scale_fn
|
||||
self.scale_mode = 'iterations'
|
||||
else:
|
||||
self.scale_fn = scale_fn
|
||||
self.scale_mode = scale_mode
|
||||
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(CyclicLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def _triangular_scale_fn(self, x):
|
||||
return 1.
|
||||
|
||||
def _triangular2_scale_fn(self, x):
|
||||
return 1 / (2. ** (x - 1))
|
||||
|
||||
def _exp_range_scale_fn(self, x):
|
||||
return self.gamma**(x)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
# Calculates the learning rate at batch index.
|
||||
cycle = math.floor(1 + i / self.total_size)
|
||||
x = 1. + i / self.total_size - cycle
|
||||
if x <= self.step_ratio:
|
||||
scale_factor = x / self.step_ratio
|
||||
else:
|
||||
scale_factor = (x - 1) / (self.step_ratio - 1)
|
||||
|
||||
base_height = (self.max_lr - self.base_lr) * scale_factor
|
||||
if self.scale_mode == 'cycle':
|
||||
lr = self.base_lr + base_height * self.scale_fn(cycle)
|
||||
else:
|
||||
lr = self.base_lr + base_height * self.scale_fn(i)
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CosineAnnealingWarmRestarts(_LRScheduler):
|
||||
r"""Set the learning rate using a cosine annealing schedule, where
|
||||
:math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` is the
|
||||
number of epochs since the last restart and :math:`T_{i}` is the number
|
||||
of epochs between two warm restarts in SGDR:
|
||||
|
||||
.. math::
|
||||
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
|
||||
\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
|
||||
|
||||
When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
|
||||
When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
|
||||
|
||||
It has been proposed in
|
||||
`SGDR: Stochastic Gradient Descent with Warm Restarts`_.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
T_0 (int): Number of iterations for the first restart.
|
||||
T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
|
||||
eta_min (float, optional): Minimum learning rate. Default: 0.
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
|
||||
def __init__(self, lr, steps_per_epoch, max_epoch, T_0, T_mult=1, eta_min=0, warmup_epochs=0):
|
||||
if T_0 <= 0 or not isinstance(T_0, int):
|
||||
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
|
||||
if T_mult < 1 or not isinstance(T_mult, int):
|
||||
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
|
||||
self.T_0 = T_0
|
||||
self.T_i = T_0
|
||||
self.T_mult = T_mult
|
||||
self.eta_min = eta_min
|
||||
self.T_cur = 0
|
||||
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(CosineAnnealingWarmRestarts, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
if i % self.steps_per_epoch == 0 and i > 0:
|
||||
self.T_cur += 1
|
||||
if self.T_cur >= self.T_i:
|
||||
self.T_cur = self.T_cur - self.T_i
|
||||
self.T_i = self.T_i * self.T_mult
|
||||
|
||||
lr = self.eta_min + (self.base_lr - self.eta_min) * \
|
||||
(1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class OneCycleLR(_LRScheduler):
|
||||
r"""Sets the learning rate of each parameter group according to the
|
||||
1cycle learning rate policy. The 1cycle policy anneals the learning
|
||||
rate from an initial learning rate to some maximum learning rate and then
|
||||
from that maximum learning rate to some minimum learning rate much lower
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every batch.
|
||||
This scheduler is not chainable.
|
||||
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
steps_per_epoch (int): The number of steps per epoch to train for. This is
|
||||
used along with epochs in order to infer the total number of steps in the cycle.
|
||||
max_epoch (int): The number of epochs to train for. This is used along
|
||||
with steps_per_epoch in order to infer the total number of steps in the cycle.
|
||||
pct_start (float): The percentage of the cycle (in number of steps) spent
|
||||
increasing the learning rate.
|
||||
Default: 0.3
|
||||
anneal_strategy (str): {'cos', 'linear'}
|
||||
Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
|
||||
linear annealing.
|
||||
Default: 'cos'
|
||||
div_factor (float): Determines the max learning rate via
|
||||
max_lr = lr * div_factor
|
||||
Default: 25
|
||||
final_div_factor (float): Determines the minimum learning rate via
|
||||
min_lr = lr / final_div_factor
|
||||
Default: 1e4
|
||||
warmup_epochs (int): The number of epochs to Warmup.
|
||||
Default: 0
|
||||
|
||||
|
||||
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
|
||||
https://arxiv.org/abs/1708.07120
|
||||
"""
|
||||
def __init__(self,
|
||||
lr,
|
||||
steps_per_epoch,
|
||||
max_epoch,
|
||||
pct_start=0.3,
|
||||
anneal_strategy='cos',
|
||||
div_factor=25.,
|
||||
final_div_factor=1e4,
|
||||
warmup_epochs=0):
|
||||
|
||||
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
|
||||
super(OneCycleLR, self).__init__(lr, max_epoch, steps_per_epoch)
|
||||
|
||||
self.step_size_up = float(pct_start * self.total_steps) - 1
|
||||
self.step_size_down = float(self.total_steps - self.step_size_up) - 1
|
||||
|
||||
# Validate pct_start
|
||||
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
|
||||
raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
|
||||
|
||||
# Validate anneal_strategy
|
||||
if anneal_strategy not in ['cos', 'linear']:
|
||||
raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
|
||||
if anneal_strategy == 'cos':
|
||||
self.anneal_func = self._annealing_cos
|
||||
elif anneal_strategy == 'linear':
|
||||
self.anneal_func = self._annealing_linear
|
||||
|
||||
# Initialize learning rate variables
|
||||
self.max_lr = lr * div_factor
|
||||
self.min_lr = lr / final_div_factor
|
||||
|
||||
def _annealing_cos(self, start, end, pct):
|
||||
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
||||
cos_out = math.cos(math.pi * pct) + 1
|
||||
return end + (start - end) / 2.0 * cos_out
|
||||
|
||||
def _annealing_linear(self, start, end, pct):
|
||||
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
|
||||
return (end - start) * pct + start
|
||||
|
||||
def get_lr(self):
|
||||
warmup_steps = self.warmup.get_warmup_steps()
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(self.total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = self.warmup.get_lr(i+1)
|
||||
else:
|
||||
if i <= self.step_size_up:
|
||||
lr = self.anneal_func(self.base_lr, self.max_lr, i / self.step_size_up)
|
||||
|
||||
else:
|
||||
down_step_num = i - self.step_size_up
|
||||
lr = self.anneal_func(self.max_lr, self.min_lr, down_step_num / self.step_size_down)
|
||||
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train FCN8s."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
from src.data import dataset as data_generator
|
||||
from src.loss import loss
|
||||
from src.utils.lr_scheduler import CosineAnnealingLR
|
||||
from src.nets.FCN8s import FCN8s
|
||||
from src.config import FCN8s_VOC2012_cfg
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindspore FCN training')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||
args, _ = parser.parse_known_args()
|
||||
return args
|
||||
|
||||
|
||||
def train():
|
||||
args = parse_args()
|
||||
cfg = FCN8s_VOC2012_cfg
|
||||
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||
# init multicards training
|
||||
if device_num > 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||
device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
# dataset
|
||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||
image_std=cfg.image_std,
|
||||
data_file=cfg.data_file,
|
||||
batch_size=cfg.batch_size,
|
||||
crop_size=cfg.crop_size,
|
||||
max_scale=cfg.max_scale,
|
||||
min_scale=cfg.min_scale,
|
||||
ignore_label=cfg.ignore_label,
|
||||
num_classes=cfg.num_classes,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=args.rank,
|
||||
shard_num=args.group_size)
|
||||
dataset = dataset.get_dataset(repeat=1)
|
||||
|
||||
net = FCN8s(n_class=cfg.num_classes)
|
||||
loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)
|
||||
|
||||
# load pretrained vgg16 parameters to init FCN8s
|
||||
if cfg.ckpt_vgg16:
|
||||
param_vgg = load_checkpoint(cfg.ckpt_vgg16)
|
||||
param_dict = {}
|
||||
for layer_id in range(1, 6):
|
||||
sub_layer_num = 2 if layer_id < 3 else 3
|
||||
for sub_layer_id in range(sub_layer_num):
|
||||
# conv param
|
||||
y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
|
||||
x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
|
||||
param_dict[y_weight] = param_vgg[x_weight]
|
||||
# BatchNorm param
|
||||
y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1)
|
||||
y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1)
|
||||
x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1)
|
||||
x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1)
|
||||
param_dict[y_gamma] = param_vgg[x_gamma]
|
||||
param_dict[y_beta] = param_vgg[x_beta]
|
||||
load_param_into_net(net, param_dict)
|
||||
# load pretrained FCN8s
|
||||
elif cfg.ckpt_pre_trained:
|
||||
param_dict = load_checkpoint(cfg.ckpt_pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
lr_scheduler = CosineAnnealingLR(cfg.base_lr,
|
||||
cfg.train_epochs,
|
||||
iters_per_epoch,
|
||||
cfg.train_epochs,
|
||||
warmup_epochs=0,
|
||||
eta_min=0)
|
||||
lr = Tensor(lr_scheduler.get_lr())
|
||||
|
||||
# loss scale
|
||||
manager_loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=cfg.loss_scale)
|
||||
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
# callback for saving ckpts
|
||||
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
||||
loss_cb = LossMonitor()
|
||||
cbs = [time_cb, loss_cb]
|
||||
|
||||
if args.rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.train_dir, config=config_ck)
|
||||
cbs.append(ckpoint_cb)
|
||||
|
||||
model.train(cfg.train_epochs, dataset, callbacks=cbs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
Loading…
Reference in New Issue