forked from mindspore-Ecosystem/mindspore
commit
9f8502598f
Binary file not shown.
After Width: | Height: | Size: 97 KiB |
Binary file not shown.
After Width: | Height: | Size: 52 KiB |
Binary file not shown.
After Width: | Height: | Size: 350 KiB |
|
@ -0,0 +1,293 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [RCAN描述](#RCAN描述)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练](#训练)
|
||||
- [评估](#评估)
|
||||
- [参数配置](#参数配置)
|
||||
- [训练过程](#训练过程)
|
||||
- [训练](#训练-1)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估-1)
|
||||
- [模型导出](#模型导出)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [训练性能](#训练性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# RCAN描述
|
||||
|
||||
卷积神经网络(CNN)深度是图像超分辨率(SR)的关键。然而,我们观察到图像SR的更深的网络更难训练。低分辨率的输入和特征包含了丰富的低频信息,这些信息在不同的信道中被平等地对待,从而阻碍了CNNs的表征能力。为了解决这些问题,我们提出了超深剩余信道注意网络(RCAN)。具体地说,我们提出了一种残差中残差(RIR)结构来形成非常深的网络,它由多个具有长跳跃连接的残差组组成。每个剩余组包含一些具有短跳过连接的剩余块。同时,RIR允许通过多跳连接绕过丰富的低频信息,使主网集中学习高频信息。此外,我们提出了一种通道注意机制,通过考虑通道间的相互依赖性,自适应地重新缩放通道特征。大量的实验表明,我们的RCAN与现有的方法相比,具有更好的精确度和视觉效果。
|
||||
![CA](https://gitee.com/bcc2974874275/mindspore/raw/master/model_zoo/research/cv/RCAN/Figs/CA.PNG)
|
||||
通道注意(CA)结构。
|
||||
![RCAB](https://gitee.com/bcc2974874275/mindspore/raw/master/model_zoo/research/cv/RCAN/Figs/RCAB.PNG)
|
||||
剩余通道注意块(RCAB)结构。
|
||||
![RCAN](https://gitee.com/bcc2974874275/mindspore/raw/master/model_zoo/research/cv/RCAN/Figs/RCAN.PNG)
|
||||
本文提出的剩余信道注意网络(RCAN)的体系结构。
|
||||
|
||||
# 数据集
|
||||
|
||||
## 使用的数据集:[Div2k](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
|
||||
|
||||
- 数据集大小:约7.12GB,共900张图像
|
||||
- 训练集:800张图像
|
||||
- 测试集:100张图像
|
||||
- 基准数据集可下载如下:[Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html)、[Set14](https://deepai.org/dataset/set14-super-resolution)、[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)、[Urban100](http://vllab.ucmerced.edu/wlai24/LapSRN/)。
|
||||
- 数据格式:png文件
|
||||
- 注:数据将在src/data/DIV2K.py中处理。
|
||||
|
||||
```bash
|
||||
DIV2K
|
||||
├── DIV2K_test_LR_bicubic
|
||||
│ ├── X2
|
||||
│ │ ├── 0901x2.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x2.png
|
||||
│ ├── X3
|
||||
│ │ ├── 0901x3.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x3.png
|
||||
│ └── X4
|
||||
│ ├── 0901x4.png
|
||||
│ ├─ ...
|
||||
│ └── 1000x4.png
|
||||
├── DIV2K_test_LR_unknown
|
||||
│ ├── X2
|
||||
│ │ ├── 0901x2.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x2.png
|
||||
│ ├── X3
|
||||
│ │ ├── 0901x3.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x3.png
|
||||
│ └── X4
|
||||
│ ├── 0901x4.png
|
||||
│ ├─ ...
|
||||
│ └── 1000x4.png
|
||||
├── DIV2K_train_HR
|
||||
│ ├── 0001.png
|
||||
│ ├─ ...
|
||||
│ └── 0900.png
|
||||
├── DIV2K_train_LR_bicubic
|
||||
│ ├── X2
|
||||
│ │ ├── 0001x2.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 0900x2.png
|
||||
│ ├── X3
|
||||
│ │ ├── 0001x3.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 0900x3.png
|
||||
│ └── X4
|
||||
│ ├── 0001x4.png
|
||||
│ ├─ ...
|
||||
│ └── 0900x4.png
|
||||
└── DIV2K_train_LR_unknown
|
||||
├── X2
|
||||
│ ├── 0001x2.png
|
||||
│ ├─ ...
|
||||
│ └── 0900x2.png
|
||||
├── X3
|
||||
│ ├── 0001x3.png
|
||||
│ ├─ ...
|
||||
│ └── 0900x3.png
|
||||
└── X4
|
||||
├── 0001x4.png
|
||||
├─ ...
|
||||
└── 0900x4.png
|
||||
```
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 准备Ascend处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── model_zoo
|
||||
├── README.md // 所有模型相关说明
|
||||
├── RCAN
|
||||
├── scripts
|
||||
│ ├── run_distribute_train.sh // Ascend分布式训练shell脚本
|
||||
│ ├── run_eval.sh // eval验证shell脚本
|
||||
│ ├── run_ascend_standalone.sh // Ascend训练shell脚本
|
||||
├── src
|
||||
│ ├── data
|
||||
│ │ ├──common.py //公共数据集
|
||||
│ │ ├──div2k.py //div2k数据集
|
||||
│ │ ├──srdata.py //所有数据集
|
||||
│ ├── rcan_model.py //RCAN网络
|
||||
│ ├── metrics.py //PSNR,SSIM计算器
|
||||
│ ├── args.py //超参数
|
||||
├── train.py //训练脚本
|
||||
├── eval.py //评估脚本
|
||||
├── export.py //模型导出
|
||||
├── README.md // 自述文件
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
### 训练
|
||||
|
||||
```bash
|
||||
用法:python train.py [--device_target][--dir_data]
|
||||
[--ckpt_path][--test_every][--task_id]
|
||||
选项:
|
||||
--device_target 训练后端类型,Ascend,默认为Ascend。
|
||||
--dir_data 数据集存储路径。
|
||||
--ckpt_path 存放检查点的路径。
|
||||
--test_every 每N批进行一次试验。
|
||||
--task_id 任务ID。
|
||||
```
|
||||
|
||||
### 评估
|
||||
|
||||
```bash
|
||||
用法:python eval.py [--device_target][--dir_data]
|
||||
[--task_id][--scale][--data_test]
|
||||
[--ckpt_save_path]
|
||||
|
||||
选项:
|
||||
--device_target 评估后端类型,Ascend。
|
||||
--dir_data 数据集路径。
|
||||
--task_id 任务id。
|
||||
--scale 超分倍数。
|
||||
--data_test 测试数据集名字。
|
||||
--ckpt_save_path 检查点路径。
|
||||
```
|
||||
|
||||
## 参数配置
|
||||
|
||||
在args.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- RCAN配置,div2k数据集
|
||||
|
||||
```bash
|
||||
"lr": 0.0001, # 学习率
|
||||
"epochs": 500, # 训练轮次数
|
||||
"batch_size": 16, # 输入张量的批次大小
|
||||
"weight_decay": 0, # 权重衰减
|
||||
"loss_scale": 1024, # 损失放大
|
||||
"buffer_size": 10, # 混洗缓冲区大小
|
||||
"init_loss_scale":65536, # 比例因子
|
||||
"betas":(0.9, 0.999), # ADAM beta
|
||||
"weight_decay":0, # 权重衰减
|
||||
"num_layers":4, # 层数
|
||||
"test_every":4000, # 每N批进行一次试验
|
||||
"n_resgroups":10, # 残差组数
|
||||
"reduction":16, # 特征映射数减少
|
||||
"patch_size":48, # 输出块大小
|
||||
"scale":'2', # 超分辨率比例尺
|
||||
"task_id":0, # 任务id
|
||||
"n_colors":3, # 颜色通道数
|
||||
"n_resblocks":20, # 残差块数
|
||||
"n_feats":64, # 特诊图数量
|
||||
"res_scale":1, # residual scaling
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 训练
|
||||
|
||||
#### Ascend处理器环境运行RCAN
|
||||
|
||||
- 单设备训练(1p)
|
||||
- 二倍超分task_id 0
|
||||
- 三倍超分task_id 1
|
||||
- 四倍超分task_id 2
|
||||
|
||||
```bash
|
||||
sh scripts/run_ascend_distribute.sh [TRAIN_DATA_DIR]
|
||||
```
|
||||
|
||||
- 分布式训练
|
||||
- 二倍超分task_id 0
|
||||
- 三倍超分task_id 1
|
||||
- 四倍超分task_id 2
|
||||
|
||||
```bash
|
||||
sh scripts/run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
```
|
||||
|
||||
- 分布式训练需要提前创建JSON格式的HCCL配置文件。具体操作,参见:<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
||||
- 评估过程如下,需要指定数据集类型为“Set5”或“B100”。
|
||||
|
||||
```bash
|
||||
sh scripts/eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
|
||||
```
|
||||
|
||||
- 上述python命令在后台运行,可通过`eval.log`文件查看结果。
|
||||
|
||||
## 模型导出
|
||||
|
||||
```bash
|
||||
用法:python export.py [--batch_size] [--ckpt_path] [--file_format]
|
||||
选项:
|
||||
--batch_size 输入张量的批次大小。
|
||||
--ckpt_path 检查点路径。
|
||||
--file_format 可选 ['MINDIR', 'AIR', 'ONNX'], 默认['MINDIR']。
|
||||
```
|
||||
|
||||
- FILE_FORMAT 可选 ['MINDIR', 'AIR', 'ONNX'], 默认['MINDIR']。
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 训练性能
|
||||
|
||||
| 参数 | RCAN(Ascend) |
|
||||
| -------------------------- | ---------------------------------------------- |
|
||||
| 模型版本 | RCAN |
|
||||
| 资源 | Ascend 910; |
|
||||
| 上传日期 | 2021-06-30 |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 |DIV2K |
|
||||
| 训练参数 |epoch=500, batch_size = 16, lr=0.0001 |
|
||||
| 优化器 | Adam |
|
||||
| 损失函数 | L1loss |
|
||||
| 输出 | 超分辨率图片 |
|
||||
| 损失 | |
|
||||
| 速度 | 8卡:205毫秒/步 |
|
||||
| 总时长 | 8卡:14.74小时 |
|
||||
| 调优检查点 | 0.2 GB(.ckpt 文件) |
|
||||
| 脚本 |[RCAN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/RCAN) | |
|
||||
|
||||
### 评估性能
|
||||
|
||||
| 参数 | RCAN(Ascend) |
|
||||
| ------------------- | --------------------------- |
|
||||
| 模型版本 | RCAN |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021-07-11 |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | Set5,B100 |
|
||||
| batch_size | 1 |
|
||||
| 输出 | 超分辨率图片 |
|
||||
| 准确率 | 单卡:Set5: 38.15/B100:32.28 |
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval script"""
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.args import args
|
||||
import src.rcan_model as rcan
|
||||
from src.data.srdata import SRData
|
||||
from src.metrics import calc_psnr, quantize, calc_ssim
|
||||
from src.data.div2k import DIV2K
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
|
||||
context.set_context(max_call_depth=10000)
|
||||
def eval_net():
|
||||
"""eval"""
|
||||
if args.epochs == 0:
|
||||
args.epochs = 1e8
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
||||
if args.data_test[0] == 'DIV2K':
|
||||
train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
|
||||
else:
|
||||
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
|
||||
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
|
||||
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
|
||||
net_m = rcan.RCAN(args)
|
||||
if args.ckpt_path:
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
net_m.set_train(False)
|
||||
|
||||
print('load mindspore net successfully.')
|
||||
num_imgs = train_de_dataset.get_dataset_size()
|
||||
psnrs = np.zeros((num_imgs, 1))
|
||||
ssims = np.zeros((num_imgs, 1))
|
||||
for batch_idx, imgs in enumerate(train_loader):
|
||||
lr = imgs['LR']
|
||||
hr = imgs['HR']
|
||||
lr = Tensor(lr, mstype.float32)
|
||||
pred = net_m(lr)
|
||||
pred_np = pred.asnumpy()
|
||||
pred_np = quantize(pred_np, 255)
|
||||
psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
|
||||
pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
|
||||
hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
|
||||
ssim = calc_ssim(pred_np, hr, args.scale[0])
|
||||
print("current psnr: ", psnr)
|
||||
print("current ssim: ", ssim)
|
||||
psnrs[batch_idx, 0] = psnr
|
||||
ssims[batch_idx, 0] = ssim
|
||||
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
|
||||
print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
time_start = time.time()
|
||||
print("Start eval function!")
|
||||
eval_net()
|
||||
time_end = time.time()
|
||||
print('eval_time: %f' % (time_end - time_start))
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export net together with checkpoint into air/mindir/onnx models"""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.args import args as arg
|
||||
from src.rcan_model import RCAN
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='rcan export')
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
|
||||
parser.add_argument("--file_name", type=str, default="rcan", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
|
||||
args_1 = parser.parse_args()
|
||||
|
||||
|
||||
def run_export(args):
|
||||
""" export """
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
||||
net = RCAN(arg)
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
print('load mindspore net and checkpoint successfully.')
|
||||
inputs = Tensor(np.zeros([args.batch_size, 3, 678, 1020], np.float32))
|
||||
export(net, inputs, file_name=args.file_name, file_format=args.file_format)
|
||||
print('export successfully!')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export(args_1)
|
|
@ -0,0 +1,69 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
|
||||
nohup python train.py \
|
||||
--batch_size 16 \
|
||||
--lr 1e-4 \
|
||||
--scale 2+3+4 \
|
||||
--task_id 0 \
|
||||
--dir_data $PATH2 \
|
||||
--epochs 500 \
|
||||
--test_every 4000 \
|
||||
--patch_size 48 > train.log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,73 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
#!/bin/bash
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
|
||||
env >env.log
|
||||
|
||||
nohup python train.py \
|
||||
--batch_size 16 \
|
||||
--lr 1e-4 \
|
||||
--scale 2+3+4 \
|
||||
--task_id 0 \
|
||||
--dir_data $PATH1 \
|
||||
--epochs 500 \
|
||||
--test_every 4000 \
|
||||
--patch_size 48 > train.log 2>&1 &
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
DATASET_TYPE=$3
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]; then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation ..."
|
||||
|
||||
python eval.py \
|
||||
--dir_data=${PATH1} \
|
||||
--batch_size 1 \
|
||||
--test_only \
|
||||
--ext "img" \
|
||||
--data_test=${DATASET_TYPE} \
|
||||
--ckpt_path=${PATH2} \
|
||||
--task_id 0 \
|
||||
--scale 2 > eval.log 2>&1 &
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""args"""
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
parser = argparse.ArgumentParser(description='RCAN')
|
||||
|
||||
# Hardware specifications
|
||||
parser.add_argument('--seed', type=int, default=1,
|
||||
help='random seed')
|
||||
|
||||
|
||||
# Data specifications
|
||||
parser.add_argument('--dir_data', type=str, default='/cache/data/',
|
||||
help='dataset directory')
|
||||
parser.add_argument('--data_train', type=str, default='DIV2K',
|
||||
help='train dataset name')
|
||||
parser.add_argument('--data_test', type=str, default='DIV2K',
|
||||
help='test dataset name')
|
||||
parser.add_argument('--data_range', type=str, default='1-800/801-810',
|
||||
help='train/test data range')
|
||||
parser.add_argument('--ext', type=str, default='sep',
|
||||
help='dataset file extension')
|
||||
parser.add_argument('--scale', type=str, default='4',
|
||||
help='super resolution scale')
|
||||
parser.add_argument('--patch_size', type=int, default=48,
|
||||
help='output patch size')
|
||||
parser.add_argument('--rgb_range', type=int, default=255,
|
||||
help='maximum value of RGB')
|
||||
parser.add_argument('--n_colors', type=int, default=3,
|
||||
help='number of color channels to use')
|
||||
parser.add_argument('--no_augment', action='store_true',
|
||||
help='do not use data augmentation')
|
||||
|
||||
# Model specifications
|
||||
parser.add_argument('--model', default='RCAN',
|
||||
help='model name')
|
||||
parser.add_argument('--act', type=str, default='relu',
|
||||
help='activation function')
|
||||
parser.add_argument('--n_resblocks', type=int, default=20,
|
||||
help='number of residual blocks')
|
||||
parser.add_argument('--n_feats', type=int, default=64,
|
||||
help='number of feature maps')
|
||||
parser.add_argument('--res_scale', type=float, default=1,
|
||||
help='residual scaling')
|
||||
|
||||
|
||||
# Option for Residual channel attention network (RCAN)
|
||||
parser.add_argument('--n_resgroups', type=int, default=10,
|
||||
help='number of residual groups')
|
||||
parser.add_argument('--reduction', type=int, default=16,
|
||||
help='number of feature maps reduction')
|
||||
|
||||
# Training specifications
|
||||
parser.add_argument('--test_every', type=int, default=4000,
|
||||
help='do test per every N batches')
|
||||
parser.add_argument('--epochs', type=int, default=1000,
|
||||
help='number of epochs to train')
|
||||
parser.add_argument('--batch_size', type=int, default=16,
|
||||
help='input batch size for training')
|
||||
parser.add_argument('--test_only', action='store_true',
|
||||
help='set this option to test the model')
|
||||
|
||||
|
||||
# Optimization specifications
|
||||
parser.add_argument('--lr', type=float, default=1e-5,
|
||||
help='learning rate')
|
||||
parser.add_argument('--init_loss_scale', type=float, default=65536.,
|
||||
help='scaling factor')
|
||||
parser.add_argument('--decay', type=str, default='200',
|
||||
help='learning rate decay type')
|
||||
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
|
||||
help='ADAM beta')
|
||||
parser.add_argument('--epsilon', type=float, default=1e-8,
|
||||
help='ADAM epsilon for numerical stability')
|
||||
parser.add_argument('--weight_decay', type=float, default=0,
|
||||
help='weight decay')
|
||||
parser.add_argument('--gclip', type=float, default=0,
|
||||
help='gradient clipping threshold (0 = no clipping)')
|
||||
|
||||
# ckpt specifications
|
||||
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
|
||||
help='path to save ckpt')
|
||||
parser.add_argument('--ckpt_save_interval', type=int, default=10,
|
||||
help='save ckpt frequency, unit is epoch')
|
||||
parser.add_argument('--ckpt_save_max', type=int, default=100,
|
||||
help='max number of saved ckpt')
|
||||
parser.add_argument('--ckpt_path', type=str, default='',
|
||||
help='path of saved ckpt')
|
||||
|
||||
# Task
|
||||
parser.add_argument('--task_id', type=int, default=0)
|
||||
|
||||
# ModelArts
|
||||
parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
|
||||
help='train on modelarts or not, default is False')
|
||||
parser.add_argument('--data_url', type=str, default='', help='the directory path of saved file')
|
||||
|
||||
|
||||
args, unparsed = parser.parse_known_args()
|
||||
|
||||
args.scale = [int(x) for x in args.scale.split("+")]
|
||||
args.data_train = args.data_train.split('+')
|
||||
args.data_test = args.data_test.split('+')
|
||||
|
||||
if args.epochs == 0:
|
||||
args.epochs = 1e8
|
||||
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""common"""
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_patch(*args, patch_size=96, scale=2, input_large=False):
|
||||
"""get_patch"""
|
||||
ih, iw = args[0].shape[:2]
|
||||
|
||||
tp = patch_size
|
||||
ip = tp // scale
|
||||
|
||||
ix = random.randrange(0, iw - ip + 1)
|
||||
iy = random.randrange(0, ih - ip + 1)
|
||||
|
||||
if not input_large:
|
||||
tx, ty = scale * ix, scale * iy
|
||||
else:
|
||||
tx, ty = ix, iy
|
||||
|
||||
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def set_channel(*args, n_channels=3):
|
||||
"""set_channel"""
|
||||
def _set_channel(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
|
||||
c = img.shape[2]
|
||||
if n_channels == 3 and c == 1:
|
||||
img = np.concatenate([img] * n_channels, 2)
|
||||
|
||||
return img[:, :, :n_channels]
|
||||
|
||||
return [_set_channel(a) for a in args]
|
||||
|
||||
|
||||
def np2Tensor(*args, rgb_range=255):
|
||||
""" np2Tensor"""
|
||||
def _np2Tensor(img):
|
||||
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
||||
input_data = np_transpose.astype(np.float32)
|
||||
output = input_data * (rgb_range / 255)
|
||||
return output
|
||||
return [_np2Tensor(a) for a in args]
|
||||
|
||||
|
||||
def augment(*args, hflip=True, rot=True):
|
||||
"""augment("""
|
||||
hflip = hflip and random.random() < 0.5
|
||||
vflip = rot and random.random() < 0.5
|
||||
rot90 = rot and random.random() < 0.5
|
||||
|
||||
def _augment(img):
|
||||
"""augment"""
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
img = img[::-1, :, :]
|
||||
if rot90:
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
return [_augment(a) for a in args]
|
||||
|
||||
|
||||
def search(root, target="JPEG"):
|
||||
"""search"""
|
||||
item_list = []
|
||||
items = os.listdir(root)
|
||||
for item in items:
|
||||
path = os.path.join(root, item)
|
||||
if os.path.isdir(path):
|
||||
item_list.extend(search(path, target))
|
||||
elif path.split('/')[-1].startswith(target):
|
||||
item_list.append(path)
|
||||
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
|
||||
item_list.append(path)
|
||||
return item_list
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""div2k"""
|
||||
import os
|
||||
from src.data.srdata import SRData
|
||||
|
||||
|
||||
class DIV2K(SRData):
|
||||
"""DIV2K"""
|
||||
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
|
||||
self.dir_hr = None
|
||||
self.dir_lr = None
|
||||
data_range = [r.split('-') for r in args.data_range.split('/')]
|
||||
if train:
|
||||
data_range = data_range[0]
|
||||
else:
|
||||
if args.test_only and len(data_range) == 1:
|
||||
data_range = data_range[0]
|
||||
else:
|
||||
data_range = data_range[1]
|
||||
|
||||
self.begin, self.end = list(map(int, data_range))
|
||||
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
|
||||
|
||||
def _scan(self):
|
||||
names_hr, names_lr = super(DIV2K, self)._scan()
|
||||
names_hr = names_hr[self.begin - 1:self.end]
|
||||
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
|
||||
|
||||
return names_hr, names_lr
|
||||
|
||||
def _set_filesystem(self, dir_data):
|
||||
super(DIV2K, self)._set_filesystem(dir_data)
|
||||
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
||||
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""""srdata"""
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
import pickle
|
||||
import imageio
|
||||
from src.data import common
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
class SRData:
|
||||
"""srdata"""
|
||||
|
||||
def __init__(self, args, name='', train=True, benchmark=False):
|
||||
self.derain_lr_test = None
|
||||
self.derain_hr_test = None
|
||||
self.deblur_lr_test = None
|
||||
self.deblur_hr_test = None
|
||||
self.args = args
|
||||
self.name = name
|
||||
self.train = train
|
||||
self.split = 'train' if train else 'test'
|
||||
self.do_eval = True
|
||||
self.benchmark = benchmark
|
||||
self.input_large = (args.model == 'VDSR')
|
||||
self.scale = args.scale
|
||||
self.idx_scale = 0
|
||||
self._set_filesystem(args.dir_data)
|
||||
self._set_img(args)
|
||||
if train:
|
||||
self._repeat(args)
|
||||
|
||||
def _set_img(self, args):
|
||||
"""set_img"""
|
||||
if args.ext.find('img') < 0:
|
||||
path_bin = os.path.join(self.apath, 'bin')
|
||||
os.makedirs(path_bin, exist_ok=True)
|
||||
list_hr, list_lr = self._scan()
|
||||
if args.ext.find('img') >= 0 or self.benchmark:
|
||||
self.images_hr, self.images_lr = list_hr, list_lr
|
||||
elif args.ext.find('sep') >= 0:
|
||||
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
|
||||
for s in self.scale:
|
||||
if s == 1:
|
||||
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
|
||||
else:
|
||||
os.makedirs(
|
||||
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
|
||||
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
|
||||
for h in list_hr:
|
||||
b = h.replace(self.apath, path_bin)
|
||||
b = b.replace(self.ext[0], '.pt')
|
||||
self.images_hr.append(b)
|
||||
self._check_and_load(args.ext, h, b, verbose=True)
|
||||
for i, ll in enumerate(list_lr):
|
||||
for l in ll:
|
||||
b = l.replace(self.apath, path_bin)
|
||||
b = b.replace(self.ext[1], '.pt')
|
||||
self.images_lr[i].append(b)
|
||||
self._check_and_load(args.ext, l, b, verbose=True)
|
||||
|
||||
def _repeat(self, args):
|
||||
"""repeat"""
|
||||
n_patches = args.batch_size * args.test_every
|
||||
n_images = len(args.data_train) * len(self.images_hr)
|
||||
if n_images == 0:
|
||||
self.repeat = 0
|
||||
else:
|
||||
self.repeat = max(n_patches // n_images, 1)
|
||||
|
||||
def _scan(self):
|
||||
"""_scan"""
|
||||
names_hr = sorted(
|
||||
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
|
||||
names_lr = [[] for _ in self.scale]
|
||||
for f in names_hr:
|
||||
filename, _ = os.path.splitext(os.path.basename(f))
|
||||
for si, s in enumerate(self.scale):
|
||||
if s != 1:
|
||||
scale = s
|
||||
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
|
||||
.format(s, filename, scale, self.ext[1])))
|
||||
for si, s in enumerate(self.scale):
|
||||
if s == 1:
|
||||
names_lr[si] = names_hr
|
||||
return names_hr, names_lr
|
||||
|
||||
def _set_filesystem(self, dir_data):
|
||||
"""set_filesystem"""
|
||||
self.apath = os.path.join(dir_data, self.name[0])
|
||||
self.dir_hr = os.path.join(self.apath, 'HR')
|
||||
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
|
||||
self.ext = ('.png', '.png')
|
||||
|
||||
def _check_and_load(self, ext, img, f, verbose=True):
|
||||
"""check_and_load"""
|
||||
if not os.path.isfile(f) or ext.find('reset') >= 0:
|
||||
if verbose:
|
||||
print('Making a binary: {}'.format(f))
|
||||
with open(f, 'wb') as _f:
|
||||
pickle.dump(imageio.imread(img), _f)
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
def __getitem__(self, idx):
|
||||
"""get item"""
|
||||
lr, hr, filename = self._load_file(idx)
|
||||
pair = self.get_patch(lr, hr)
|
||||
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
return pair_t[0], pair_t[1]
|
||||
|
||||
def __len__(self):
|
||||
"""length of hr"""
|
||||
if self.train:
|
||||
return len(self.images_hr) * self.repeat
|
||||
return len(self.images_hr)
|
||||
|
||||
def _get_index(self, idx):
|
||||
"""get_index"""
|
||||
if self.train:
|
||||
return idx % len(self.images_hr)
|
||||
return idx
|
||||
|
||||
def _load_file_hr(self, idx):
|
||||
"""load_file_hr"""
|
||||
idx = self._get_index(idx)
|
||||
f_hr = self.images_hr[idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
||||
if self.args.ext == 'img' or self.benchmark:
|
||||
hr = imageio.imread(f_hr)
|
||||
elif self.args.ext.find('sep') >= 0:
|
||||
with open(f_hr, 'rb') as _f:
|
||||
hr = pickle.load(_f)
|
||||
return hr, filename
|
||||
|
||||
def _load_file(self, idx):
|
||||
"""load_file"""
|
||||
idx = self._get_index(idx)
|
||||
# print(idx,flush=True)
|
||||
f_hr = self.images_hr[idx]
|
||||
f_lr = self.images_lr[self.idx_scale][idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
||||
if self.args.ext == 'img' or self.benchmark:
|
||||
hr = imageio.imread(f_hr)
|
||||
lr = imageio.imread(f_lr)
|
||||
elif self.args.ext.find('sep') >= 0:
|
||||
with open(f_hr, 'rb') as _f:
|
||||
hr = pickle.load(_f)
|
||||
with open(f_lr, 'rb') as _f:
|
||||
lr = pickle.load(_f)
|
||||
return lr, hr, filename
|
||||
|
||||
def get_patch_hr(self, hr):
|
||||
"""get_patch_hr"""
|
||||
if self.train:
|
||||
hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
|
||||
return hr
|
||||
|
||||
def get_patch_img_hr(self, img, patch_size=96, scale=2):
|
||||
"""get_patch_img_hr"""
|
||||
ih, iw = img.shape[:2]
|
||||
tp = patch_size
|
||||
ip = tp // scale
|
||||
ix = random.randrange(0, iw - ip + 1)
|
||||
iy = random.randrange(0, ih - ip + 1)
|
||||
ret = img[iy:iy + ip, ix:ix + ip, :]
|
||||
return ret
|
||||
|
||||
def get_patch(self, lr, hr):
|
||||
"""get_patch"""
|
||||
scale = self.scale[self.idx_scale]
|
||||
if self.train:
|
||||
lr, hr = common.get_patch(
|
||||
lr, hr,
|
||||
patch_size=self.args.patch_size * scale,
|
||||
scale=scale)
|
||||
if not self.args.no_augment:
|
||||
lr, hr = common.augment(lr, hr)
|
||||
else:
|
||||
ih, iw = lr.shape[:2]
|
||||
hr = hr[0:ih * scale, 0:iw * scale]
|
||||
return lr, hr
|
||||
|
||||
def set_scale(self, idx_scale):
|
||||
"""set_scale"""
|
||||
if not self.input_large:
|
||||
self.idx_scale = idx_scale
|
||||
else:
|
||||
self.idx_scale = random.randint(0, len(self.scale) - 1)
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""metrics"""
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def quantize(img, rgb_range):
|
||||
"""quantize image range to 0-255"""
|
||||
pixel_range = 255 / rgb_range
|
||||
img = np.multiply(img, pixel_range)
|
||||
img = np.clip(img, 0, 255)
|
||||
img = np.round(img) / pixel_range
|
||||
return img
|
||||
|
||||
|
||||
def calc_psnr(sr, hr, scale, rgb_range):
|
||||
"""calculate psnr"""
|
||||
hr = np.float32(hr)
|
||||
sr = np.float32(sr)
|
||||
diff = (sr - hr) / rgb_range
|
||||
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
|
||||
diff = np.multiply(diff, gray_coeffs).sum(1)
|
||||
if hr.size == 1:
|
||||
return 0
|
||||
|
||||
shave = scale
|
||||
valid = diff[..., shave:-shave, shave:-shave]
|
||||
mse = np.mean(pow(valid, 2))
|
||||
return -10 * math.log10(mse)
|
||||
|
||||
|
||||
def rgb2ycbcr(img, y_only=True):
|
||||
"""from rgb space to ycbcr space"""
|
||||
img.astype(np.float32)
|
||||
if y_only:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
return rlt
|
||||
|
||||
|
||||
def calc_ssim(img1, img2, scale):
|
||||
"""calculate ssim"""
|
||||
def ssim(img1, img2):
|
||||
"""calculate ssim"""
|
||||
C1 = (0.01 * 255) ** 2
|
||||
C2 = (0.03 * 255) ** 2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1 ** 2
|
||||
mu2_sq = mu2 ** 2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
border = scale
|
||||
img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
|
||||
img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1_y = img1_y[border:h - border, border:w - border]
|
||||
img2_y = img2_y[border:h - border, border:w - border]
|
||||
|
||||
if img1_y.ndim == 2:
|
||||
return ssim(img1_y, img2_y)
|
||||
if img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for _ in range(3):
|
||||
ssims.append(ssim(img1, img2))
|
||||
|
||||
return np.array(ssims).mean()
|
||||
if img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
|
||||
raise ValueError('Wrong input image dimensions.')
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""rcan"""
|
||||
import math
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import Tensor, Parameter
|
||||
|
||||
|
||||
def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
|
||||
"""rcan"""
|
||||
return nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
padding=(kernel_size // 2), has_bias=has_bias, pad_mode='pad')
|
||||
|
||||
|
||||
class MeanShift(nn.Conv2d):
|
||||
"""rcan"""
|
||||
def __init__(self,
|
||||
rgb_range,
|
||||
rgb_mean=(0.4488, 0.4371, 0.4040),
|
||||
rgb_std=(1.0, 1.0, 1.0),
|
||||
sign=-1):
|
||||
"""rcan"""
|
||||
super(MeanShift, self).__init__(3, 3, kernel_size=1)
|
||||
self.reshape = P.Reshape()
|
||||
self.eye = P.Eye()
|
||||
std = Tensor(rgb_std, mstype.float32)
|
||||
self.weight.set_data(
|
||||
self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1)))
|
||||
self.weight.requires_grad = False
|
||||
self.bias = Parameter(
|
||||
sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False)
|
||||
self.has_bias = True
|
||||
|
||||
|
||||
def _pixelsf_(x, scale):
|
||||
"""rcan"""
|
||||
n, c, ih, iw = x.shape
|
||||
oh = ih * scale
|
||||
ow = iw * scale
|
||||
oc = c // (scale ** 2)
|
||||
output = P.Transpose()(x, (0, 2, 1, 3))
|
||||
output = P.Reshape()(output, (n, ih, oc * scale, scale, iw))
|
||||
output = P.Transpose()(output, (0, 1, 2, 4, 3))
|
||||
output = P.Reshape()(output, (n, ih, oc, scale, ow))
|
||||
output = P.Transpose()(output, (0, 2, 1, 3, 4))
|
||||
output = P.Reshape()(output, (n, oc, oh, ow))
|
||||
return output
|
||||
|
||||
|
||||
class SmallUpSampler(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self, conv, upsize, n_feats, has_bias=True):
|
||||
"""rcan"""
|
||||
super(SmallUpSampler, self).__init__()
|
||||
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias)
|
||||
self.reshape = P.Reshape()
|
||||
self.upsize = upsize
|
||||
self.pixelsf = _pixelsf_
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
x = self.conv(x)
|
||||
output = self.pixelsf(x, self.upsize)
|
||||
return output
|
||||
|
||||
|
||||
class Upsampler(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self, conv, scale, n_feats, has_bias=True):
|
||||
"""rcan"""
|
||||
super(Upsampler, self).__init__()
|
||||
m = []
|
||||
if (scale & (scale - 1)) == 0:
|
||||
for _ in range(int(math.log(scale, 2))):
|
||||
m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
|
||||
elif scale == 3:
|
||||
m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
|
||||
self.net = nn.SequentialCell(m)
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class AdaptiveAvgPool2d(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self):
|
||||
"""rcan"""
|
||||
super().__init__()
|
||||
self.ReduceMean = ops.ReduceMean(keep_dims=True)
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
return self.ReduceMean(x, 0)
|
||||
|
||||
|
||||
class CALayer(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self, channel, reduction=16):
|
||||
"""rcan"""
|
||||
super(CALayer, self).__init__()
|
||||
# global average pooling: feature --> point
|
||||
self.avg_pool = AdaptiveAvgPool2d()
|
||||
# feature channel downscale and upscale --> channel weight
|
||||
self.conv_du = nn.SequentialCell([
|
||||
nn.Conv2d(channel, channel // reduction, 1, padding=0, has_bias=True, pad_mode='pad'),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(channel // reduction, channel, 1, padding=0, has_bias=True, pad_mode='pad'),
|
||||
nn.Sigmoid()
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
y = self.avg_pool(x)
|
||||
y = self.conv_du(y)
|
||||
return x * y
|
||||
|
||||
|
||||
class RCAB(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True
|
||||
, bn=False, act=nn.ReLU(), res_scale=1):
|
||||
"""rcan"""
|
||||
super(RCAB, self).__init__()
|
||||
self.modules_body = []
|
||||
for i in range(2):
|
||||
self.modules_body.append(conv(n_feat, n_feat, kernel_size, has_bias=has_bias))
|
||||
if bn: self.modules_body.append(nn.BatchNorm2d(n_feat))
|
||||
if i == 0: self.modules_body.append(act)
|
||||
self.modules_body.append(CALayer(n_feat, reduction))
|
||||
self.body = nn.SequentialCell(*self.modules_body)
|
||||
self.res_scale = res_scale
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
res = self.body(x)
|
||||
res += x
|
||||
return res
|
||||
|
||||
|
||||
class ResidualGroup(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks):
|
||||
"""rcan"""
|
||||
super(ResidualGroup, self).__init__()
|
||||
modules_body = []
|
||||
modules_body = [
|
||||
RCAB(
|
||||
conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1) \
|
||||
for _ in range(n_resblocks)]
|
||||
modules_body.append(conv(n_feat, n_feat, kernel_size))
|
||||
self.body = nn.SequentialCell(*modules_body)
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
res = self.body(x)
|
||||
res += x
|
||||
return res
|
||||
|
||||
|
||||
class RCAN(nn.Cell):
|
||||
"""rcan"""
|
||||
def __init__(self, args, conv=default_conv):
|
||||
"""rcan"""
|
||||
super(RCAN, self).__init__()
|
||||
|
||||
n_resgroups = args.n_resgroups
|
||||
n_resblocks = args.n_resblocks
|
||||
n_feats = args.n_feats
|
||||
kernel_size = 3
|
||||
reduction = args.reduction
|
||||
idx = args.task_id
|
||||
scale = args.scale[idx]
|
||||
self.dytpe = mstype.float16
|
||||
|
||||
# RGB mean for DIV2K
|
||||
rgb_mean = (0.4488, 0.4371, 0.4040)
|
||||
rgb_std = (1.0, 1.0, 1.0)
|
||||
|
||||
self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dytpe)
|
||||
|
||||
# define head module
|
||||
modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe)
|
||||
|
||||
# define body module
|
||||
modules_body = [
|
||||
ResidualGroup(
|
||||
conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dytpe) \
|
||||
for _ in range(n_resgroups)]
|
||||
|
||||
modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dytpe))
|
||||
|
||||
# define tail module
|
||||
modules_tail = [
|
||||
Upsampler(conv, scale, n_feats).to_float(self.dytpe),
|
||||
conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)]
|
||||
|
||||
self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dytpe)
|
||||
|
||||
self.head = modules_head
|
||||
self.body = nn.SequentialCell(modules_body)
|
||||
self.tail = nn.SequentialCell(modules_tail)
|
||||
|
||||
def construct(self, x):
|
||||
"""rcan"""
|
||||
x = self.sub_mean(x)
|
||||
x = self.head(x)
|
||||
res = self.body(x)
|
||||
res += x
|
||||
x = self.tail(res)
|
||||
x = self.add_mean(x)
|
||||
return x
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train"""
|
||||
import os
|
||||
import time
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from src.args import args
|
||||
from src.data.div2k import DIV2K
|
||||
from src.rcan_model import RCAN
|
||||
|
||||
def train():
|
||||
"""train"""
|
||||
set_seed(1)
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
rank_id = int(os.getenv('RANK_ID', '0'))
|
||||
device_num = int(os.getenv('RANK_SIZE', '1'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
|
||||
train_dataset.set_scale(args.task_id)
|
||||
|
||||
if args.modelArts_mode:
|
||||
import moxing as mox
|
||||
local_data_url = '/cache/data'
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=device_num, gradients_mean=True)
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
|
||||
train_dataset.set_scale(args.task_id)
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
|
||||
shard_id=rank_id, shuffle=True)
|
||||
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
|
||||
net_m = RCAN(args)
|
||||
print("Init net weights successfully")
|
||||
|
||||
if args.ckpt_path:
|
||||
param_dict = load_checkpoint(args.pth_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
print("Load net weight successfully")
|
||||
step_size = train_de_dataset.get_dataset_size()
|
||||
lr = []
|
||||
for i in range(0, args.epochs):
|
||||
cur_lr = args.lr / (2 ** ((i + 1) // 200))
|
||||
lr.extend([cur_lr] * step_size)
|
||||
opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=1024.0)
|
||||
loss = nn.L1Loss()
|
||||
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \
|
||||
scale_factor=2, scale_window=1000)
|
||||
model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
|
||||
keep_checkpoint_max=args.ckpt_save_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="rcan", directory=args.ckpt_save_path, config=config_ck)
|
||||
if device_id == 0:
|
||||
cb += [ckpt_cb]
|
||||
model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_start = time.time()
|
||||
train()
|
||||
time_end = time.time()
|
||||
print('train_time: %f' % (time_end - time_start))
|
Loading…
Reference in New Issue