deeplabv3plus first commit

kzx2018 2021-03-16 11:16:20 +08:00
parent d8484efdd4
commit 3ed4dee7bf
25 changed files with 2291 additions and 0 deletions

@@ -0,0 +1,531 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [DeepLabV3+ Description](#deeplabv3-description)
    - [Description](#description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Usage](#usage)
            - [Running on Ascend](#running-on-ascend)
        - [Results](#results)
    - [Evaluation Process](#evaluation-process)
        - [Usage](#usage-1)
            - [Running on Ascend](#running-on-ascend-1)
        - [Results](#results-1)
            - [Training Accuracy](#training-accuracy)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# DeepLabV3+ Description
## Description
DeepLab is a family of semantic image segmentation models. DeepLabV3+ fuses multi-scale information through an encoder-decoder structure while retaining the atrous convolutions and the ASPP layer of its predecessors.
Its backbone is a ResNet model, which improves both the robustness and the running speed of semantic segmentation.
For details about the network, see the [paper][1].
`Chen, Liang-Chieh, et al. "Encoder-decoder with atrous separable convolution for semantic image segmentation." Proceedings of the European conference on computer vision (ECCV). 2018.`
[1]: https://arxiv.org/abs/1802.02611
# Model Architecture
With ResNet-101 as the backbone, DeepLabV3+ fuses multi-scale information through an encoder-decoder structure and uses atrous convolution for dense feature extraction.
# Dataset
The Pascal VOC dataset and the Semantic Boundaries Dataset (SBD).
- Download the segmentation datasets.
- Prepare the training data list file. The list file stores the relative paths of image and annotation pairs, for example:
```text
VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg VOCdevkit/VOC2012/SegmentationClassGray/2007_000032.png
VOCdevkit/VOC2012/JPEGImages/2007_000039.jpg VOCdevkit/VOC2012/SegmentationClassGray/2007_000039.png
VOCdevkit/VOC2012/JPEGImages/2007_000063.jpg VOCdevkit/VOC2012/SegmentationClassGray/2007_000063.png
VOCdevkit/VOC2012/JPEGImages/2007_000068.jpg VOCdevkit/VOC2012/SegmentationClassGray/2007_000068.png
......
```
You can also generate the data list file automatically by running `python get_dataset_list.py --data_root=/PATH/TO/DATA`.
- Configure and run get_dataset_mindrecord.sh to convert the dataset to MindRecord files. Parameters in scripts/get_dataset_mindrecord.sh:
```
--data_root                 root path of the training data
--data_lst                  training data list (prepared as above)
--dst_path                  destination path of the MindRecord files
--num_shards                number of shards of the MindRecord files
--shuffle                   shuffle or not
```
# Features
## Mixed Precision
The [mixed precision](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html) training method uses both single-precision and half-precision data to speed up the training of deep neural networks while preserving the accuracy attainable with pure single-precision training. Mixed precision increases computing speed and reduces memory usage, which makes it possible to train larger models or use larger batch sizes on specific hardware.
Taking FP16 operators as an example, if the input data type is FP32, the MindSpore backend automatically reduces the precision to process the data. You can open the INFO log and search for "reduce precision" to view operators whose precision was reduced.
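The sketch below shows how mixed precision and loss scaling are typically enabled in MindSpore of this era through `Model`; it is a minimal illustration with placeholder network, loss, and optimizer, not the exact wiring of this repository's train.py.
```python
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# placeholders; in this repository they would be DeepLabV3Plus,
# SoftmaxCrossEntropyLoss, and the Momentum optimizer built in train.py
network = nn.Dense(16, 21)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(network.trainable_params(), learning_rate=0.08, momentum=0.9)

model = Model(network, loss_fn=loss, optimizer=opt,
              amp_level="O3",  # run the network in float16 where supported
              loss_scale_manager=FixedLossScaleManager(2048, drop_overflow_update=False))
```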
# Environment Requirements
- Hardware (Ascend)
    - Prepare an Ascend processor to set up the hardware environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
- Install the Python packages listed in requirements.txt.
- Generate the config json file for 8-device training; a sketch of the resulting file follows the command below.
```
# enter from the project root
cd src/tools/
python3 get_multicards_json.py 10.111.*.*
# 10.111.*.* is the host IP address
```
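The generated rank table is a JSON description of the host and its NPUs. Its exact schema depends on the Ascend driver version, so the shape below is only a hypothetical illustration, not the tool's guaranteed output:
```text
{
  "version": "1.0",
  "server_count": "1",
  "server_list": [{
    "server_id": "10.111.*.*",
    "device": [
      {"device_id": "0", "device_ip": "192.98.92.100", "rank_id": "0"},
      ... one entry per NPU, device_id/rank_id 0-7 ...
    ]
  }],
  "status": "completed"
}
```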
# Quick Start
After installing MindSpore via the official website, you can follow the steps below for training and evaluation:
- Running on Ascend
Based on the original DeepLabV3+ paper, we conduct two training experiments on the VOCaug (also called trainaug) dataset and evaluate on the VOC val dataset.
Run the following training script to configure single-device training parameters:
```bash
run_alone_train.sh
```
Follow the steps below for 8-device training:
1. Train s16 on the VOCaug dataset, fine-tuning the ResNet-101 pretrained model. The script is as follows:
```bash
run_distribute_train_s16_r1.sh
```
2. Train s8 on the VOCaug dataset, fine-tuning the model from the previous step. The script is as follows:
```bash
run_distribute_train_s8_r1.sh
```
3. Train s8 on the VOCtrain dataset, fine-tuning the model from the previous step. The script is as follows:
```bash
run_distribute_train_s8_r2.sh
```
The evaluation steps are as follows:
1. Evaluate s16 on the VOC val dataset. The evaluation script is as follows:
```bash
run_eval_s16.sh
```
2. Evaluate multiscale s16 on the VOC val dataset. The evaluation script is as follows:
```bash
run_eval_s16_multiscale.sh
```
3. Evaluate multiscale and flipped s16 on the VOC val dataset. The evaluation script is as follows:
```bash
run_eval_s16_multiscale_flip.sh
```
4. Evaluate s8 on the VOC val dataset. The evaluation script is as follows:
```bash
run_eval_s8.sh
```
5. Evaluate multiscale s8 on the VOC val dataset. The evaluation script is as follows:
```bash
run_eval_s8_multiscale.sh
```
6. Evaluate multiscale and flipped s8 on the VOC val dataset. The evaluation script is as follows:
```bash
run_eval_s8_multiscale_flip.sh
```
# Script Description
## Script and Sample Code
```shell
.
└──deeplabv3plus
  ├── script
    ├── get_dataset_mindrecord.sh          # convert the raw data to MindRecord files
    ├── run_alone_train.sh                 # launch Ascend standalone training (1 device)
    ├── run_distribute_train_s16_r1.sh     # launch Ascend distributed training of s16 on VOCaug (8 devices)
    ├── run_distribute_train_s8_r1.sh      # launch Ascend distributed training of s8 on VOCaug (8 devices)
    ├── run_distribute_train_s8_r2.sh      # launch Ascend distributed training of s8 on VOCtrain (8 devices)
    ├── run_eval_s16.sh                    # launch Ascend evaluation with s16
    ├── run_eval_s16_multiscale.sh         # launch Ascend evaluation with multiscale s16
    ├── run_eval_s16_multiscale_flip.sh    # launch Ascend evaluation with multiscale and flipped s16
    ├── run_eval_s8.sh                     # launch Ascend evaluation with s8
    ├── run_eval_s8_multiscale.sh          # launch Ascend evaluation with multiscale s8
    ├── run_eval_s8_multiscale_flip.sh     # launch Ascend evaluation with multiscale and flipped s8
  ├── src
    ├── tools
      ├── get_dataset_list.py              # generate the data list files
      ├── get_dataset_mindrecord.py        # convert data to MindRecord files
      ├── get_multicards_json.py           # generate the rank table file
      ├── get_pretrained_model.py          # get the ResNet pretrained model
    ├── dataset.py                         # data preprocessing
    ├── deeplab_v3plus.py                  # DeepLabV3+ network structure
    ├── learning_rates.py                  # learning rate generation
    ├── loss.py                            # loss definition for DeepLabV3+
  ├── eval.py                              # evaluation network
  ├── train.py                             # training network
  ├── requirements.txt                     # requirements file
  └── README.md
```
## Script Parameters
Default configuration:
```bash
"data_file":"/PATH/TO/MINDRECORD_NAME"            # dataset path
"device_target":Ascend                            # training backend type
"train_epochs":300                                # total number of epochs
"batch_size":32                                   # batch size of the input tensor
"crop_size":513                                   # crop size
"base_lr":0.08                                    # initial learning rate
"lr_type":cos                                     # decay mode for the learning rate
"min_scale":0.5                                   # minimum scale of data augmentation
"max_scale":2.0                                   # maximum scale of data augmentation
"ignore_label":255                                # label value to ignore
"num_classes":21                                  # number of classes
"model":DeepLabV3plus_s16                         # model to select
"ckpt_pre_trained":"/PATH/TO/PRETRAIN_MODEL"      # path of the pretrained checkpoint to load
"is_distributed":                                 # set to True for distributed training
"save_steps":410                                  # interval (in steps) for saving checkpoints
"freeze_bn":                                      # set freeze_bn to True to freeze batch normalization
"keep_checkpoint_max":200                         # maximum number of checkpoints to keep
```
## Training Process
### Usage
#### Running on Ascend
Based on the original DeepLabV3+ paper, we conduct two training experiments on the VOCaug (also called trainaug) dataset and evaluate on the VOC val dataset.
Run the following training script to configure single-device training parameters:
```bash
# run_alone_train.sh
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
--train_dir=${train_path}/ckpt \
--train_epochs=200 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.015 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--save_steps=1500 \
--keep_checkpoint_max=200 >log 2>&1 &
```
Follow the steps below for 8-device training:
1. Train s16 on the VOCaug dataset, fine-tuning the ResNet-101 pretrained model. The script is as follows:
```bash
# run_distribute_train_s16_r1.sh
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=$i
export DEVICE_ID=`expr $i + $RANK_START_ID`
echo 'start rank='$i', device id='$DEVICE_ID'...'
mkdir ${train_path}/device$DEVICE_ID
cd ${train_path}/device$DEVICE_ID
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=300 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.08 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=410 \
--keep_checkpoint_max=200 >log 2>&1 &
done
```
2. Train s8 on the VOCaug dataset, fine-tuning the model from the previous step. The script is as follows:
```bash
# run_distribute_train_s8_r1.sh
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=$i
export DEVICE_ID=`expr $i + $RANK_START_ID`
echo 'start rank='$i', device id='$DEVICE_ID'...'
mkdir ${train_path}/device$DEVICE_ID
cd ${train_path}/device$DEVICE_ID
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=800 \
--batch_size=16 \
--crop_size=513 \
--base_lr=0.02 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--loss_scale=2048 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=820 \
--keep_checkpoint_max=200 >log 2>&1 &
done
```
3. Train s8 on the VOCtrain dataset, fine-tuning the model from the previous step. The script is as follows:
```bash
# run_distribute_train_s8_r2.sh
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=$i
export DEVICE_ID=`expr $i + $RANK_START_ID`
echo 'start rank='$i', device id='$DEVICE_ID'...'
mkdir ${train_path}/device$DEVICE_ID
cd ${train_path}/device$DEVICE_ID
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=300 \
--batch_size=16 \
--crop_size=513 \
--base_lr=0.008 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--loss_scale=2048 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=110 \
--keep_checkpoint_max=200 >log 2>&1 &
done
```
#### Running on ModelArts
Configure the training parameters as in the following example to launch a ModelArts training job:
```shell
python train.py --train_url=/PATH/TO/OUTPUT_DIR \
--data_url=/PATH/TO/MINDRECORD \
--model=DeepLabV3plus_s16 \
--modelArts_mode=True \
--dataset_filename=MINDRECORD_NAME \
--pretrainedmodel_filename=PRETRAIN_MODELNAME \
--train_epochs=300 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.08 \
--lr_type=cos \
               --save_steps=410
```
### Results
#### Running on Ascend
- Training VOCaug with the s16 structure
```bash
# distributed training result (8P)
epoch: 1 step: 41, loss is 0.81338423
epoch time: 202199.339 ms, per step time: 4931.691 ms
epoch: 2 step: 41, loss is 0.34089813
epoch time: 23811.338 ms, per step time: 580.764 ms
epoch: 3 step: 41, loss is 0.32335973
epoch time: 23794.863 ms, per step time: 580.363 ms
epoch: 4 step: 41, loss is 0.18254203
epoch time: 23796.674 ms, per step time: 580.407 ms
epoch: 5 step: 41, loss is 0.27708685
epoch time: 23794.654 ms, per step time: 580.357 ms
epoch: 6 step: 41, loss is 0.37388346
epoch time: 23845.658 ms, per step time: 581.601 ms
...
```
- Training VOCaug with the s8 structure
```bash
# distributed training result (8P)
epoch: 1 step: 82, loss is 0.073864505
epoch time: 226610.999 ms, per step time: 2763.549 ms
epoch: 2 step: 82, loss is 0.06908825
epoch time: 44474.187 ms, per step time: 542.368 ms
epoch: 3 step: 82, loss is 0.059860937
epoch time: 44485.142 ms, per step time: 542.502 ms
epoch: 4 step: 82, loss is 0.084193744
epoch time: 44472.924 ms, per step time: 542.353 ms
epoch: 5 step: 82, loss is 0.072242916
epoch time: 44466.738 ms, per step time: 542.277 ms
epoch: 6 step: 82, loss is 0.04948996
epoch time: 44474.549 ms, per step time: 542.373 ms
...
```
- Training VOCtrain with the s8 structure
```bash
# distributed training result (8P)
epoch: 1 step: 11, loss is 0.0055908263
epoch time: 183966.044 ms, per step time: 16724.186 ms
epoch: 2 step: 11, loss is 0.008914589
epoch time: 5985.108 ms, per step time: 544.101 ms
epoch: 3 step: 11, loss is 0.0073758443
epoch time: 5977.932 ms, per step time: 543.448 ms
epoch: 4 step: 11, loss is 0.00677738
epoch time: 5978.866 ms, per step time: 543.533 ms
epoch: 5 step: 11, loss is 0.0053799236
epoch time: 5987.879 ms, per step time: 544.353 ms
epoch: 6 step: 11, loss is 0.0049248594
epoch time: 5979.642 ms, per step time: 543.604 ms
...
```
#### Running on ModelArts
- Training VOCaug with the s16 structure
```bash
epoch: 1 step: 41, loss is 0.6122837
epoch: 2 step: 41, loss is 0.4066103
epoch: 3 step: 41, loss is 0.3504579
...
```
## Evaluation Process
### Usage
#### Running on Ascend
Configure the checkpoint with --ckpt_path and run the script; the mIOU is printed in eval_path/eval_log.
```bash
./run_eval_s16.sh                     # test s16
./run_eval_s16_multiscale.sh          # test s16 + multiscale
./run_eval_s16_multiscale_flip.sh     # test s16 + multiscale + flip
./run_eval_s8.sh                      # test s8
./run_eval_s8_multiscale.sh           # test s8 + multiscale
./run_eval_s8_multiscale_flip.sh      # test s8 + multiscale + flip
```
An example evaluation script is as follows:
```bash
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--flip \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
```
### Results
Run the applicable training scripts to obtain the results. To get the same figures, follow the steps in the Quick Start.
#### Training Accuracy
| **Network** | OS=16 | OS=8 | MS | Flip | mIOU | mIOU in paper |
| :----------: | :-----: | :----: | :----: | :-----: | :-----: | :-------------: |
| deeplab_v3+ | √ |  |  |  | 79.78 | 78.85 |
| deeplab_v3+ | √ |  | √ |  | 80.59 | 80.09 |
| deeplab_v3+ | √ |  | √ | √ | 80.76 | 80.22 |
| deeplab_v3+ |  | √ |  |  | 79.56 | 79.35 |
| deeplab_v3+ |  | √ | √ |  | 80.43 | 80.43 |
| deeplab_v3+ |  | √ | √ | √ | 80.69 | 80.57 |

Note: OS refers to the output stride, and MS refers to multiscale.
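For reference, the reported mIOU is the mean of the per-class IoU values computed from the confusion matrix accumulated over the validation set, exactly as in eval.py:
```python
import numpy as np

def mean_iou(hist):
    """hist: num_classes x num_classes confusion matrix (rows: labels, columns: predictions)."""
    # per-class IoU = TP / (TP + FP + FN)
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    return np.nanmean(iu)
```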
# Model Description
## Performance
### Evaluation Performance
| Parameter | Ascend 910 |
| -------------------------- | -------------------------------------- |
| Model version | DeepLabV3+ |
| Resources | Ascend 910 |
| Upload date | 2021-03-16 |
| MindSpore version | 1.1.1 |
| Dataset | PASCAL VOC2012 + SBD |
| Training parameters | epoch = 300, batch_size = 32 (s16_r1); epoch = 800, batch_size = 16 (s8_r1); epoch = 300, batch_size = 16 (s8_r2) |
| Optimizer | Momentum |
| Loss function | Softmax cross entropy |
| Output | Probability |
| Loss | 0.0041095633 |
| Performance | 187736.386 ms (1-device, s16)<br>44474.187 ms (8-device, s16) |
| Fine-tuned checkpoint | 453M (.ckpt file) |
| Script | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/deeplabv3plus) |
# Description of Random Situation
The seed inside the "create_dataset" function is set in dataset.py, and random seeds are also used in train.py.
# ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,218 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval deeplabv3+"""
import os
import argparse
import numpy as np
import cv2
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.deeplab_v3plus import DeepLabV3Plus
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
device_id=int(os.getenv('DEVICE_ID')))
def parse_args():
"""parse_args"""
parser = argparse.ArgumentParser('MindSpore DeepLabV3+ eval')
# val data
parser.add_argument('--data_root', type=str, default='', help='root path of val data')
parser.add_argument('--data_lst', type=str, default='', help='list of val data')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--crop_size', type=int, default=513, help='crop size')
    parser.add_argument('--image_mean', type=float, nargs=3, default=[103.53, 116.28, 123.675], help='image mean')
    parser.add_argument('--image_std', type=float, nargs=3, default=[57.375, 57.120, 58.395], help='image std')
parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
parser.add_argument('--flip', action='store_true', help='perform left-right flip')
parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
# model
parser.add_argument('--model', type=str, default='', help='select model')
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
args, _ = parser.parse_known_args()
return args
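# cal_hist accumulates an n x n confusion matrix: entry (i, j) counts pixels whose
# ground-truth class is i and predicted class is j; labels outside [0, n), such as
# ignore_label, are excluded by the mask k.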
def cal_hist(a, b, n):
k = (a >= 0) & (a < n)
return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
def resize_long(img, long_size=513):
h, w, _ = img.shape
if h > w:
new_h = long_size
new_w = int(1.0 * long_size * w / h)
else:
new_w = long_size
new_h = int(1.0 * long_size * h / w)
imo = cv2.resize(img, (new_w, new_h))
return imo
class BuildEvalNetwork(nn.Cell):
def __init__(self, network):
super(BuildEvalNetwork, self).__init__()
self.network = network
self.softmax = nn.Softmax(axis=1)
def construct(self, input_data):
output = self.network(input_data)
output = self.softmax(output)
return output
def pre_process(args, img_, crop_size=513):
"""pre_process"""
# resize
img_ = resize_long(img_, crop_size)
resize_h, resize_w, _ = img_.shape
# mean, std
image_mean = np.array(args.image_mean)
image_std = np.array(args.image_std)
img_ = (img_ - image_mean) / image_std
# pad to crop_size
pad_h = crop_size - img_.shape[0]
pad_w = crop_size - img_.shape[1]
if pad_h > 0 or pad_w > 0:
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
# hwc to chw
img_ = img_.transpose((2, 0, 1))
return img_, resize_h, resize_w
def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
"""eval_batch"""
result_lst = []
batch_size = len(img_lst)
batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
resize_hw = []
for l in range(batch_size):
img_ = img_lst[l]
img_, resize_h, resize_w = pre_process(args, img_, crop_size)
batch_img[l] = img_
resize_hw.append([resize_h, resize_w])
batch_img = np.ascontiguousarray(batch_img)
net_out = eval_net(Tensor(batch_img, mstype.float32))
net_out = net_out.asnumpy()
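    # test-time augmentation: add the logits of the horizontally flipped input,
    # flipped back to the original orientation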
if flip:
batch_img = batch_img[:, :, :, ::-1]
net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
for bs in range(batch_size):
probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
probs_ = cv2.resize(probs_, (ori_w, ori_h))
result_lst.append(probs_)
return result_lst
def eval_batch_scales(args, eval_net, img_lst, scales,
base_crop_size=513, flip=True):
"""eval_batch_scales"""
sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
print(sizes_)
for crop_size_ in sizes_[1:]:
probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
for pl, _ in enumerate(probs_lst):
probs_lst[pl] += probs_lst_tmp[pl]
result_msk = []
for i in probs_lst:
result_msk.append(i.argmax(axis=2))
return result_msk
def net_eval():
"""net_eval"""
args = parse_args()
# data list
with open(args.data_lst) as f:
img_lst = f.readlines()
# network
if args.model == 'DeepLabV3plus_s16':
network = DeepLabV3Plus('eval', args.num_classes, 16, args.freeze_bn)
elif args.model == 'DeepLabV3plus_s8':
network = DeepLabV3Plus('eval', args.num_classes, 8, args.freeze_bn)
else:
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
eval_net = BuildEvalNetwork(network)
# load model
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(eval_net, param_dict)
eval_net.set_train(False)
# evaluate
hist = np.zeros((args.num_classes, args.num_classes))
batch_img_lst = []
batch_msk_lst = []
bi = 0
image_num = 0
for i, line in enumerate(img_lst):
img_path, msk_path = line.strip().split(' ')
img_path = os.path.join(args.data_root, img_path)
msk_path = os.path.join(args.data_root, msk_path)
img_ = cv2.imread(img_path)
msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
batch_img_lst.append(img_)
batch_msk_lst.append(msk_)
bi += 1
if bi == args.batch_size:
batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
base_crop_size=args.crop_size, flip=args.flip)
for mi in range(args.batch_size):
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
bi = 0
batch_img_lst = []
batch_msk_lst = []
print('processed {} images'.format(i + 1))
image_num = i
if bi > 0:
batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
base_crop_size=args.crop_size, flip=args.flip)
for mi in range(bi):
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
print('processed {} images'.format(image_num + 1))
print(hist)
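    # per-class IoU = TP / (TP + FP + FN); mean IoU averages over classes (NaNs ignored)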
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print('per-class IoU', iu)
print('mean IoU', np.nanmean(iu))
if __name__ == '__main__':
net_eval()

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export AIR file."""
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.deeplab_v3plus import DeepLabV3Plus
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='checkpoint export')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of deeplabv3 (Default: None)')
parser.add_argument('--model', type=str, default='DeepLabV3plus_s8',
choices=['DeepLabV3plus_s16', 'DeepLabV3plus_s8'],
help='Select model structure (Default: DeepLabV3plus_s8)')
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
args = parser.parse_args()
if args.model == 'DeepLabV3plus_s16':
network = DeepLabV3Plus('eval', args.num_classes, 16, True)
else:
network = DeepLabV3Plus('eval', args.num_classes, 8, True)
param_dict = load_checkpoint(args.checkpoint)
# load the parameter into net
load_param_into_net(network, param_dict)
input_data = np.random.uniform(0.0, 1.0, size=[32, 3, 513, 513]).astype(np.float32)
export(network, Tensor(input_data), file_name=args.model + '-300_11.air', file_format='AIR')

@@ -0,0 +1,28 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.deeplab_v3plus import DeepLabV3Plus
def create_network(name, *args, **kwargs):
freeze_bn = True
num_classes = kwargs["num_classes"]
if name == 'DeepLabV3plus_s16':
DeepLabV3plus_s16_network = DeepLabV3Plus('eval', num_classes, 16, freeze_bn)
return DeepLabV3plus_s16_network
if name == 'DeepLabV3plus_s8':
DeepLabV3plus_s8_network = DeepLabV3Plus('eval', num_classes, 8, freeze_bn)
return DeepLabV3plus_s8_network
raise NotImplementedError(f"{name} is not implemented in the repo")

@@ -0,0 +1,4 @@
mindspore
numpy
Pillow
opencv-python

@@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=7
python /PATH/TO/MODEL_ZOO_CODE/data/get_dataset_mindrecord.py --data_root=/PATH/TO/DATA_ROOT \
--data_lst=/PATH/TO/DATA_lst.txt \
--dst_path=/PATH/TO/MINDRECORD_NAME.mindrecord \
--num_shards=1 \
--shuffle=True

@@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=5
export SLOG_PRINT_TO_STDOUT=0
train_path=/PATH/TO/EXPERIMENTS_DIR
train_code_path=/PATH/TO/MODEL_ZOO_CODE
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/device${DEVICE_ID}
mkdir ${train_path}/ckpt
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
--train_dir=${train_path}/ckpt \
--train_epochs=200 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.015 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--save_steps=1500 \
--keep_checkpoint_max=200 >log 2>&1 &

@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
mkdir ${train_path}/device${DEVICE_ID}
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=300 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.08 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=410 \
--keep_checkpoint_max=200 >log 2>&1 &
done

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
mkdir ${train_path}/device${DEVICE_ID}
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=800 \
--batch_size=16 \
--crop_size=513 \
--base_lr=0.02 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--loss_scale=2048 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=820 \
--keep_checkpoint_max=200 >log 2>&1 &
done

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
mkdir ${train_path}/device${DEVICE_ID}
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=300 \
--batch_size=16 \
--crop_size=513 \
--base_lr=0.008 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--loss_scale=2048 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=110 \
--keep_checkpoint_max=200 >log 2>&1 &
done

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=32 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--scales=1.0 \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s16 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--flip \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--scales=1.0 \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,41 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=DeepLabV3plus_s8 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--flip \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,99 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset"""
import numpy as np
import cv2
import mindspore.dataset as de
cv2.setNumThreads(0)
class SegDataset:
"""SegDataset"""
def __init__(self,
image_mean,
image_std,
data_file='',
batch_size=32,
crop_size=512,
max_scale=2.0,
min_scale=0.5,
ignore_label=255,
num_classes=21,
num_readers=2,
num_parallel_calls=4,
shard_id=None,
shard_num=None):
self.data_file = data_file
self.batch_size = batch_size
self.crop_size = crop_size
self.image_mean = np.array(image_mean, dtype=np.float32)
self.image_std = np.array(image_std, dtype=np.float32)
self.max_scale = max_scale
self.min_scale = min_scale
self.ignore_label = ignore_label
self.num_classes = num_classes
self.num_readers = num_readers
self.num_parallel_calls = num_parallel_calls
self.shard_id = shard_id
self.shard_num = shard_num
assert max_scale > min_scale
def preprocess_(self, image, label):
"""SegDataset.preprocess_"""
# bgr image
image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
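        # random scale augmentation in [min_scale, max_scale]; cubic interpolation
        # for the image, nearest-neighbor for the label map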
sc = np.random.uniform(self.min_scale, self.max_scale)
new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
image_out = (image_out - self.image_mean) / self.image_std
h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size)
pad_h, pad_w = h_ - new_h, w_ - new_w
if pad_h > 0 or pad_w > 0:
image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
offset_h = np.random.randint(0, h_ - self.crop_size + 1)
offset_w = np.random.randint(0, w_ - self.crop_size + 1)
image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size]
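        # random horizontal flip with probability 0.5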
if np.random.uniform(0.0, 1.0) > 0.5:
image_out = image_out[:, ::-1, :]
label_out = label_out[:, ::-1]
image_out = image_out.transpose((2, 0, 1))
image_out = image_out.copy()
label_out = label_out.copy()
return image_out, label_out
def get_dataset(self, repeat=1):
"""SegDataset.get_dataset"""
data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"],
shuffle=True, num_parallel_workers=self.num_readers,
num_shards=self.shard_num, shard_id=self.shard_id)
transforms_list = self.preprocess_
data_set = data_set.map(operations=transforms_list, input_columns=["data", "label"],
output_columns=["data", "label"],
num_parallel_workers=self.num_parallel_calls)
data_set = data_set.shuffle(buffer_size=self.batch_size * 10)
data_set = data_set.batch(self.batch_size, drop_remainder=True)
data_set = data_set.repeat(repeat)
return data_set

@@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""deeplabv3plus network"""
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, weight_init='xavier_uniform')
def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, pad_mode='pad', padding=padding,
dilation=dilation, weight_init='xavier_uniform')
class Resnet(nn.Cell):
"""Resnet"""
def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
super(Resnet, self).__init__()
self.inplanes = 64
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
weight_init='xavier_uniform')
self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
if output_stride == 16:
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2,
use_batch_statistics=use_batch_statistics)
self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4],
use_batch_statistics=use_batch_statistics)
elif output_stride == 8:
self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2,
use_batch_statistics=use_batch_statistics)
self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4],
use_batch_statistics=use_batch_statistics)
def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
"""Resnet._make_layer"""
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.SequentialCell([
conv1x1(self.inplanes, planes * block.expansion, stride),
nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
])
if grids is None:
grids = [1] * blocks
layers = [
block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
use_batch_statistics=use_batch_statistics)
]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(self.inplanes, planes, dilation=base_dilation * grids[i],
use_batch_statistics=use_batch_statistics))
return nn.SequentialCell(layers)
def construct(self, x):
"""Resnet.construct"""
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.maxpool(out)
out = self.layer1(out)
low_level_feat = out
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out, low_level_feat
class Bottleneck(nn.Cell):
"""Bottleneck"""
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True):
super(Bottleneck, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
self.conv2 = conv3x3(planes, planes, stride, dilation, dilation)
self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.downsample = downsample
def construct(self, x):
"""Bottleneck.construct"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out = out + identity
out = self.relu(out)
return out
class ASPPConv(nn.Cell):
"""ASPPConv"""
def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True):
super(ASPPConv, self).__init__()
if atrous_rate == 1:
conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, weight_init='xavier_uniform')
else:
conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate,
dilation=atrous_rate, weight_init='xavier_uniform')
bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
relu = nn.ReLU()
self.aspp_conv = nn.SequentialCell([conv, bn, relu])
def construct(self, x):
out = self.aspp_conv(x)
return out
class ASPPPooling(nn.Cell):
"""ASPPPooling"""
def __init__(self, in_channels, out_channels, use_batch_statistics=True):
super(ASPPPooling, self).__init__()
self.conv = nn.SequentialCell([
nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'),
nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
nn.ReLU()
])
self.shape = P.Shape()
def construct(self, x):
size = self.shape(x)
out = nn.AvgPool2d(size[2])(x)
out = self.conv(out)
out = P.ResizeNearestNeighbor((size[2], size[3]), True)(out)
return out
class ASPP(nn.Cell):
"""ASPP"""
def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21,
use_batch_statistics=True):
super(ASPP, self).__init__()
self.phase = phase
out_channels = 256
self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics)
self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics)
self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics)
self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics)
self.aspp_pooling = ASPPPooling(in_channels, out_channels)
self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1,
weight_init='xavier_uniform')
self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.concat = P.Concat(axis=1)
self.drop = nn.Dropout(0.3)
def construct(self, x):
"""ASPP.construct"""
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.aspp_pooling(x)
x = self.concat((x1, x2))
x = self.concat((x, x3))
x = self.concat((x, x4))
x = self.concat((x, x5))
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.phase == 'train':
x = self.drop(x)
return x
class DeepLabV3Plus(nn.Cell):
"""DeepLabV3Plus"""
def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
super(DeepLabV3Plus, self).__init__()
use_batch_statistics = not freeze_bn
self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride,
use_batch_statistics=use_batch_statistics)
self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes,
use_batch_statistics=use_batch_statistics)
self.shape = P.Shape()
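        # decoder: project the low-level features to 48 channels, then concatenate
        # them with the upsampled ASPP output (256 + 48 = 304 channels into last_conv)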
self.conv2 = nn.Conv2d(256, 48, kernel_size=1, weight_init='xavier_uniform')
self.bn2 = nn.BatchNorm2d(48, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.concat = P.Concat(axis=1)
self.last_conv = nn.SequentialCell([
conv3x3(304, 256, stride=1, dilation=1, padding=1),
nn.BatchNorm2d(256, use_batch_statistics=use_batch_statistics),
nn.ReLU(),
conv3x3(256, 256, stride=1, dilation=1, padding=1),
nn.BatchNorm2d(256, use_batch_statistics=use_batch_statistics),
nn.ReLU(),
conv1x1(256, num_classes, stride=1)
])
def construct(self, x):
"""DeepLabV3Plus.construct"""
size = self.shape(x)
out, low_level_features = self.resnet(x)
size2 = self.shape(low_level_features)
out = self.aspp(out)
out = P.ResizeNearestNeighbor((size2[2], size2[3]), True)(out)
low_level_features = self.conv2(low_level_features)
low_level_features = self.bn2(low_level_features)
low_level_features = self.relu(low_level_features)
out = self.concat((out, low_level_features))
out = self.last_conv(out)
out = P.ResizeBilinear((size[2], size[3]), True)(out)
return out

@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rates"""
import numpy as np
def cosine_lr(base_lr, decay_steps, total_steps):
for i in range(total_steps):
step_ = min(i, decay_steps)
yield base_lr * 0.5 * (1 + np.cos(np.pi * step_ / decay_steps))
def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
for i in range(total_steps):
step_ = min(i, decay_steps)
yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr
def exponential_lr(base_lr, decay_steps, decay_rate, total_steps, staircase=False):
for i in range(total_steps):
if staircase:
power_ = i // decay_steps
else:
power_ = float(i) / decay_steps
yield base_lr * (decay_rate ** power_)
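# Hypothetical usage sketch (not part of this file): each generator yields one learning
# rate per step, so it can be materialized into a per-step schedule for an optimizer,
# e.g. with mindspore.nn imported as nn:
#   lrs = list(cosine_lr(base_lr=0.08, decay_steps=total_steps, total_steps=total_steps))
#   opt = nn.Momentum(params, learning_rate=lrs, momentum=0.9)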

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss"""
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
class SoftmaxCrossEntropyLoss(nn.Cell):
"""SoftmaxCrossEntropyLoss"""
def __init__(self, num_cls=21, ignore_label=255):
super(SoftmaxCrossEntropyLoss, self).__init__()
self.one_hot = P.OneHot(axis=-1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.cast = P.Cast()
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.not_equal = P.NotEqual()
self.num_cls = num_cls
self.ignore_label = ignore_label
self.mul = P.Mul()
self.sum = P.ReduceSum(False)
self.div = P.RealDiv()
self.transpose = P.Transpose()
self.reshape = P.Reshape()
def construct(self, logits, labels):
"""SoftmaxCrossEntropyLoss.construct"""
labels_int = self.cast(labels, mstype.int32)
labels_int = self.reshape(labels_int, (-1,))
logits_ = self.transpose(logits, (0, 2, 3, 1))
logits_ = self.reshape(logits_, (-1, self.num_cls))
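        # build a 0/1 mask that removes ignore_label pixels from both the loss sum
        # and the normalizer below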
weights = self.not_equal(labels_int, self.ignore_label)
weights = self.cast(weights, mstype.float32)
one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
loss = self.ce(logits_, one_hot_labels)
loss = self.mul(weights, loss)
loss = self.div(self.sum(loss), self.sum(weights))
return loss

@@ -0,0 +1,159 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get_dataset_list"""
import argparse
import os
import numpy as np
import scipy.io
from PIL import Image
parser = argparse.ArgumentParser('dataset list generator')
parser.add_argument("--data_dir", type=str, default='./', help='where dataset stored.')
args, _ = parser.parse_known_args()
data_dir = args.data_dir
print("Data dir is:", data_dir)
VOC_IMG_DIR = os.path.join(data_dir, 'VOCdevkit/VOC2012/JPEGImages')
VOC_ANNO_DIR = os.path.join(data_dir, 'VOCdevkit/VOC2012/SegmentationClass')
VOC_ANNO_GRAY_DIR = os.path.join(data_dir, 'VOCdevkit/VOC2012/SegmentationClassGray')
VOC_TRAIN_TXT = os.path.join(data_dir, 'VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt')
VOC_VAL_TXT = os.path.join(data_dir, 'VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt')
SBD_ANNO_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/cls')
SBD_IMG_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/img')
SBD_ANNO_PNG_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/cls_png')
SBD_ANNO_GRAY_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/cls_png_gray')
SBD_TRAIN_TXT = os.path.join(data_dir, 'benchmark_RELEASE/dataset/train.txt')
SBD_VAL_TXT = os.path.join(data_dir, 'benchmark_RELEASE/dataset/val.txt')
VOC_TRAIN_LST_TXT = os.path.join(data_dir, 'voc_train_lst.txt')
VOC_VAL_LST_TXT = os.path.join(data_dir, 'voc_val_lst.txt')
VOC_AUG_TRAIN_LST_TXT = os.path.join(data_dir, 'vocaug_train_lst.txt')
def __get_data_list(data_list_file):
with open(data_list_file, mode='r') as f:
return f.readlines()
def conv_voc_colorpng_to_graypng():
if not os.path.exists(VOC_ANNO_GRAY_DIR):
os.makedirs(VOC_ANNO_GRAY_DIR)
for ann in os.listdir(VOC_ANNO_DIR):
ann_im = Image.open(os.path.join(VOC_ANNO_DIR, ann))
ann_im = Image.fromarray(np.array(ann_im))
ann_im.save(os.path.join(VOC_ANNO_GRAY_DIR, ann))
def __gen_palette(cls_nums=256):
"""__gen_palette"""
palette = np.zeros((cls_nums, 3), dtype=np.uint8)
for i in range(cls_nums):
lbl = i
j = 0
while lbl:
palette[i, 0] |= (((lbl >> 0) & 1) << (7 - j))
palette[i, 1] |= (((lbl >> 1) & 1) << (7 - j))
palette[i, 2] |= (((lbl >> 2) & 1) << (7 - j))
lbl >>= 3
j += 1
return palette.flatten()
def conv_sbd_mat_to_png():
"""conv_sbd_mat_to_png"""
if not os.path.exists(SBD_ANNO_PNG_DIR):
os.makedirs(SBD_ANNO_PNG_DIR)
if not os.path.exists(SBD_ANNO_GRAY_DIR):
os.makedirs(SBD_ANNO_GRAY_DIR)
palette = __gen_palette()
for an in os.listdir(SBD_ANNO_DIR):
img_id = an[:-4]
mat = scipy.io.loadmat(os.path.join(SBD_ANNO_DIR, an))
anno = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
anno_png = Image.fromarray(anno)
# save to gray png
anno_png.save(os.path.join(SBD_ANNO_GRAY_DIR, img_id + '.png'))
# save to color png use palette
anno_png.putpalette(palette)
anno_png.save(os.path.join(SBD_ANNO_PNG_DIR, img_id + '.png'))
def create_voc_train_lst_txt():
voc_train_data_lst = __get_data_list(VOC_TRAIN_TXT)
with open(VOC_TRAIN_LST_TXT, mode='w') as f:
for id_ in voc_train_data_lst:
id_ = id_.strip()
img_ = os.path.join(VOC_IMG_DIR, id_ + '.jpg').replace('./', '')
anno_ = os.path.join(VOC_ANNO_GRAY_DIR, id_ + '.png').replace('./', '')
f.write(img_ + ' ' + anno_ + '\n')
def create_voc_val_lst_txt():
voc_val_data_lst = __get_data_list(VOC_VAL_TXT)
with open(VOC_VAL_LST_TXT, mode='w') as f:
for id_ in voc_val_data_lst:
id_ = id_.strip()
img_ = os.path.join(VOC_IMG_DIR, id_ + '.jpg').replace('./', '')
anno_ = os.path.join(VOC_ANNO_GRAY_DIR, id_ + '.png').replace('./', '')
f.write(img_ + ' ' + anno_ + '\n')
def create_voc_train_aug_lst_txt():
"""create_voc_train_aug_lst_txt"""
voc_train_data_lst = __get_data_list(VOC_TRAIN_TXT)
voc_val_data_lst = __get_data_list(VOC_VAL_TXT)
sbd_train_data_lst = __get_data_list(SBD_TRAIN_TXT)
sbd_val_data_lst = __get_data_list(SBD_VAL_TXT)
    voc_ids = {x.strip() for x in voc_train_data_lst + voc_val_data_lst}
    with open(VOC_AUG_TRAIN_LST_TXT, mode='w') as f:
        for id_ in sbd_train_data_lst + sbd_val_data_lst:
            id_ = id_.strip()
            # skip SBD ids that already appear in the official VOC train/val splits
            if id_ in voc_ids:
                continue
            img_ = os.path.join(SBD_IMG_DIR, id_ + '.jpg').replace('./', '')
            anno_ = os.path.join(SBD_ANNO_GRAY_DIR, id_ + '.png').replace('./', '')
            f.write(img_ + ' ' + anno_ + '\n')
for id_ in voc_train_data_lst:
id_ = id_.strip()
img_ = os.path.join(VOC_IMG_DIR, id_ + '.jpg').replace('./', '')
anno_ = os.path.join(VOC_ANNO_GRAY_DIR, id_ + '.png').replace('./', '')
f.write(img_ + ' ' + anno_ + '\n')
if __name__ == '__main__':
print('converting voc color png to gray png ...')
conv_voc_colorpng_to_graypng()
print('converting done.')
create_voc_train_lst_txt()
print('generating voc train list success.')
create_voc_val_lst_txt()
print('generating voc val list success.')
print('converting sbd annotations to png ...')
conv_sbd_mat_to_png()
    print('converting done.')
create_voc_train_aug_lst_txt()
print('generating voc train aug list success.')
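# Usage sketch (the data layout is an assumption; adjust the path to where
# VOCdevkit and benchmark_RELEASE were extracted):
#   python get_dataset_list.py --data_dir=/PATH/TO/DATA
# The script converts the color VOC annotations to gray PNGs, converts the SBD
# .mat annotations to PNGs, and writes voc_train_lst.txt, voc_val_lst.txt and
# vocaug_train_lst.txt under --data_dir.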

View File

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get_dataset_mindrecord"""
import os
import argparse
import ast
import numpy as np
from mindspore.mindrecord import FileWriter
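# each MindRecord sample stores the image file name, the raw JPEG bytes ("data")
# and the raw PNG label bytes ("label")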
seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}}
def parse_args():
"""parse_args"""
parser = argparse.ArgumentParser('mindrecord')
parser.add_argument('--data_root', type=str, default='', help='root path of data')
parser.add_argument('--data_lst', type=str, default='', help='list of data')
parser.add_argument('--dst_path', type=str, default='', help='save path of mindrecords')
parser.add_argument('--num_shards', type=int, default=1, help='number of shards')
    # note: type=bool would turn any non-empty string (even 'False') into True
    parser.add_argument('--shuffle', type=ast.literal_eval, default=True, help='shuffle or not')
parser_args, _ = parser.parse_known_args()
return parser_args
if __name__ == '__main__':
args = parse_args()
data = []
with open(args.data_lst) as f:
lines = f.readlines()
if args.shuffle:
np.random.shuffle(lines)
    dst_dir = os.path.dirname(args.dst_path)
    if dst_dir and not os.path.exists(dst_dir):
        os.makedirs(dst_dir)
print('number of samples:', len(lines))
writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
writer.add_schema(seg_schema, "seg_schema")
cnt = 0
    for line in lines:
        img_path, label_path = line.strip().split(' ')
sample_ = {"file_name": img_path.split('/')[-1]}
with open(os.path.join(args.data_root, img_path), 'rb') as f:
sample_['data'] = f.read()
with open(os.path.join(args.data_root, label_path), 'rb') as f:
sample_['label'] = f.read()
data.append(sample_)
cnt += 1
if cnt % 1000 == 0:
writer.write_raw_data(data)
print('number of samples written:', cnt)
data = []
if data:
writer.write_raw_data(data)
writer.commit()
print('number of samples written:', cnt)
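# Usage sketch (paths are placeholders):
#   python get_dataset_mindrecord.py --data_root=/PATH/TO/DATA \
#       --data_lst=/PATH/TO/DATA/vocaug_train_lst.txt \
#       --dst_path=/PATH/TO/MINDRECORD/vocaug.mindrecord \
#       --num_shards=8 --shuffle=True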

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""""get_multicards_json"""
import os
import sys
import json
def get_multicards_json(server_id):
""" get_multicards_json"""
    with open('/etc/hccn.conf', 'r') as hccn_file:
        hccn_configs = hccn_file.readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {'board_id': '0x0000', 'chip_info': '910', 'deploy_mode': 'lab', 'group_count': '1', 'group_list': []}
    instance_list = []
    for instance_id in range(8):
        instance = {'devices': []}
        device_id = str(instance_id)
        device_ip = device_ips[device_id]
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': '8',
'server_num': '1',
'group_name': '',
'instance_count': '8',
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(8):
hccn_table['para_plane_nic_name'].append('eth{}'.format(instance_id))
hccn_table['para_plane_nic_num'] = '8'
hccn_table['status'] = 'completed'
    table_fn = os.path.join(os.getcwd(), 'rank_table_8p.json')
print(table_fn)
with open(table_fn, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
host_server_id = sys.argv[1]
get_multicards_json(host_server_id)
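# Usage sketch: pass this server's IP address; the script reads /etc/hccn.conf
# and writes rank_table_8p.json into the current directory:
#   python get_multicards_json.py <server_ip>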

View File

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get_pretrained_model"""
import torch
from mindspore import Tensor, save_checkpoint
def torch2ms(pth_path, ckpt_path):
"""torch2ms"""
    # load on CPU so the conversion does not require a GPU
    pretrained_dict = torch.load(pth_path, map_location='cpu')
print('--------------------pretrained keys------------------------')
for k in pretrained_dict:
print(k)
print('---------------------torch2ms keys-----------------------')
new_params = []
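    # Rename torchvision ResNet-101 parameters to the layout this repo's network
    # expects: drop the fc head, rename BatchNorm parameters (running_mean ->
    # moving_mean, running_var -> moving_variance, weight -> gamma, bias -> beta)
    # and prefix everything with 'network.resnet.'.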
for k, v in pretrained_dict.items():
if 'fc' in k:
continue
if 'bn' in k or 'downsample.1' in k:
k = k.replace('running_mean', 'moving_mean')
k = k.replace('running_var', 'moving_variance')
k = k.replace('weight', 'gamma')
k = k.replace('bias', 'beta')
k = 'network.resnet.' + k
print(k)
param_dict = {'name': k, 'data': Tensor(v.detach().numpy())}
new_params.append(param_dict)
save_checkpoint(new_params, ckpt_path)
if __name__ == '__main__':
pth = "./resnet101-5d3b4d8f.pth"
ckpt = "./resnet.ckpt"
torch2ms(pth, ckpt)

View File

@ -0,0 +1,214 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train deeplabv3+"""
import os
import argparse
import ast
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
from src import dataset as data_generator
from src import loss, learning_rates
from src.deeplab_v3plus import DeepLabV3Plus
set_seed(1)
class BuildTrainNetwork(nn.Cell):
    """combine the backbone network with its loss criterion into one training cell"""
    def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
net_loss = self.criterion(output, label)
return net_loss
def parse_args():
"""parse_args"""
parser = argparse.ArgumentParser('MindSpore DeepLabV3+ training')
# Ascend or CPU
parser.add_argument('--train_dir', type=str, default='', help='where training log and CKPTs saved')
# dataset
parser.add_argument('--data_file', type=str, default='', help='path and Name of one MindRecord file')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--crop_size', type=int, default=513, help='crop size')
    # note: type=list would split a command-line string into characters, so take three floats
    parser.add_argument('--image_mean', type=float, nargs=3, default=[103.53, 116.28, 123.675], help='image mean')
    parser.add_argument('--image_std', type=float, nargs=3, default=[57.375, 57.120, 58.395], help='image std')
    parser.add_argument('--min_scale', type=float, default=0.5, help='minimum scale of data augmentation')
    parser.add_argument('--max_scale', type=float, default=2.0, help='maximum scale of data augmentation')
parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
# optimizer
parser.add_argument('--train_epochs', type=int, default=300, help='epoch')
parser.add_argument('--lr_type', type=str, default='cos', help='type of learning rate')
parser.add_argument('--base_lr', type=float, default=0.08, help='base learning rate')
parser.add_argument('--lr_decay_step', type=int, default=40000, help='learning rate decay step')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='learning rate decay rate')
parser.add_argument('--loss_scale', type=float, default=3072.0, help='loss scale')
# model
parser.add_argument('--model', type=str, default='DeepLabV3plus_s16', help='select model')
parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
parser.add_argument('--ckpt_pre_trained', type=str, default='', help='PreTrained model')
# train
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--is_distributed', action='store_true', help='distributed training')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
parser.add_argument('--save_steps', type=int, default=110, help='steps interval for saving')
parser.add_argument('--keep_checkpoint_max', type=int, default=200, help='max checkpoint for saving')
# ModelArts
parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
help='train on modelarts or not, default is False')
parser.add_argument('--train_url', type=str, default='', help='where training log and CKPTs saved')
parser.add_argument('--data_url', type=str, default='', help='the directory path of saved file')
parser.add_argument('--dataset_filename', type=str, default='', help='Name of the MindRecord file')
parser.add_argument('--pretrainedmodel_filename', type=str, default='', help='Name of the pretraining model file')
args, _ = parser.parse_known_args()
return args
def train():
"""train"""
args = parse_args()
if args.device_target == "CPU":
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
else:
        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
                            device_target="Ascend", device_id=int(os.getenv('DEVICE_ID', '0')))
# init multicards training
if args.modelArts_mode:
import moxing as mox
local_data_url = '/cache/data'
local_train_url = '/cache/ckpt'
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
if device_num > 1:
init()
args.rank = get_rank()
args.group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
device_num=args.group_size)
local_data_url = os.path.join(local_data_url, str(device_id))
# download dataset from obs to cache
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
        data_file = os.path.join(local_data_url, args.dataset_filename)
        ckpt_file = os.path.join(local_data_url, args.pretrainedmodel_filename)
train_dir = local_train_url
else:
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
device_num=args.group_size)
data_file = args.data_file
ckpt_file = args.ckpt_pre_trained
train_dir = args.train_dir
# dataset
dataset = data_generator.SegDataset(image_mean=args.image_mean,
image_std=args.image_std,
data_file=data_file,
batch_size=args.batch_size,
crop_size=args.crop_size,
max_scale=args.max_scale,
min_scale=args.min_scale,
ignore_label=args.ignore_label,
num_classes=args.num_classes,
num_readers=2,
num_parallel_calls=4,
shard_id=args.rank,
shard_num=args.group_size)
dataset = dataset.get_dataset(repeat=1)
# network
if args.model == 'DeepLabV3plus_s16':
network = DeepLabV3Plus('train', args.num_classes, 16, args.freeze_bn)
elif args.model == 'DeepLabV3plus_s8':
network = DeepLabV3Plus('train', args.num_classes, 8, args.freeze_bn)
else:
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
# loss
loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label)
loss_.add_flags_recursive(fp32=True)
train_net = BuildTrainNetwork(network, loss_)
# load pretrained model
if args.ckpt_pre_trained or args.pretrainedmodel_filename:
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(train_net, param_dict)
# optimizer
iters_per_epoch = dataset.get_dataset_size()
total_train_steps = iters_per_epoch * args.train_epochs
if args.lr_type == 'cos':
lr_iter = learning_rates.cosine_lr(args.base_lr, total_train_steps, total_train_steps)
elif args.lr_type == 'poly':
lr_iter = learning_rates.poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
elif args.lr_type == 'exp':
lr_iter = learning_rates.exponential_lr(args.base_lr, args.lr_decay_step, args.lr_decay_rate,
total_train_steps, staircase=True)
else:
raise ValueError('unknown learning rate type')
opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=lr_iter, momentum=0.9, weight_decay=0.0001,
loss_scale=args.loss_scale)
# loss scale
manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
amp_level = "O0" if args.device_target == "CPU" else "O3"
model = Model(train_net, optimizer=opt, amp_level=amp_level, loss_scale_manager=manager_loss_scale)
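    # "O3" runs the network in float16 on Ascend; the fixed loss scale applied
    # here is the same value the Momentum optimizer divides back out through its
    # loss_scale argument, so gradient magnitudes are preserved.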
# callback for saving ckpts
time_cb = TimeMonitor(data_size=iters_per_epoch)
loss_cb = LossMonitor()
cbs = [time_cb, loss_cb]
if args.rank == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args.save_steps,
keep_checkpoint_max=args.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=train_dir, config=config_ck)
cbs.append(ckpoint_cb)
model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))
if args.modelArts_mode:
# copy train result from cache to obs
if args.rank == 0:
mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
if __name__ == '__main__':
train()
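# Usage sketch for single-card training (values are illustrative; the shipped
# shell scripts set them for the real experiments):
#   python train.py --data_file=/PATH/TO/vocaug.mindrecord --train_dir=./ckpt \
#       --model=DeepLabV3plus_s16 --ckpt_pre_trained=/PATH/TO/resnet.ckpt \
#       --batch_size=32 --train_epochs=300 --lr_type=cos --base_lr=0.08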