forked from mindspore-Ecosystem/mindspore
!18924 add wdsr to model zoo
Merge pull request !18924 from Humanyue/master
Commit: 303c354c49

README_CN.md
@@ -0,0 +1,297 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [WDSR Description](#wdsr-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Export](#model-export)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# WDSR Description

WDSR, proposed in 2018, improves the accuracy of deep super-resolution networks. It took first place in all three realistic tracks of the NTIRE 2018 single image super-resolution challenge.

[Paper 1](https://arxiv.org/abs/1707.02921): Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** *2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with CVPR 2017*.

[Paper 2](https://arxiv.org/abs/1808.08718): Jiahui Yu, Yuchen Fan, Jianchao Yang, Ning Xu, Zhaowen Wang, Xinchao Wang, Thomas Huang, **"Wide Activation for Efficient and Accurate Image Super-Resolution"**, arXiv preprint arXiv:1808.08718.

# Model Architecture

The WDSR network is composed of a few basic blocks built from convolution layers. It reaches higher single image super-resolution accuracy through wider activation and linear low-rank convolution, and by introducing weight normalization. The basic blocks mainly use two operations: **1 × 1 convolution** and **3 × 3 convolution**.

The input goes through one convolution layer, then a stack of residual blocks (8 in this implementation, see src/model.py), then one more convolution layer, and finally pixel-shuffle upsampling produces the super-resolved output.
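
For reference, the sketch below shows the shape of one wide-activation residual block as implemented in src/model.py: a 1 × 1 expansion convolution, ReLU, a 1 × 1 linear low-rank convolution, a 3 × 3 convolution, and an identity skip connection. The channel sizes (64 features, 6× expansion, 0.8 linear width) are taken from that file; the class name and everything else here is only an illustrative sketch, not an additional API of the project.

```python
import mindspore.nn as nn


class WideActivationBlock(nn.Cell):
    """One WDSR residual block: 1x1 expand -> ReLU -> 1x1 linear low-rank -> 3x3, plus identity skip."""

    def __init__(self, n_feats=64, expand=6, linear=0.8):
        super().__init__()
        self.body = nn.SequentialCell([
            # widen the feature maps before the non-linearity
            nn.Conv2d(n_feats, n_feats * expand, 1, pad_mode='pad', padding=0, has_bias=True),
            nn.ReLU(),
            # linear low-rank 1x1 convolution shrinks the channels again
            nn.Conv2d(n_feats * expand, int(n_feats * linear), 1, pad_mode='pad', padding=0, has_bias=True),
            # 3x3 convolution restores the original number of feature maps
            nn.Conv2d(int(n_feats * linear), n_feats, 3, pad_mode='pad', padding=1, has_bias=True),
        ])

    def construct(self, x):
        return self.body(x) + x
```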

# Dataset

Dataset used: [DIV2K](<http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf>)

- Dataset size: 7.11 GB
    - Train set: 900 images in total
    - Test set: 100 images in total
- Data format: PNG files
- Note: the data is processed in src/data/div2k.py.

```shell
DIV2K
├── DIV2K_test_LR_bicubic
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├─ ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├─ ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├─ ...
│       └── 1000x4.png
├── DIV2K_test_LR_unknown
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├─ ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├─ ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├─ ...
│       └── 1000x4.png
├── DIV2K_train_HR
│   ├── 0001.png
│   ├─ ...
│   └── 0900.png
├── DIV2K_train_LR_bicubic
│   ├── X2
│   │   ├── 0001x2.png
│   │   ├─ ...
│   │   └── 0900x2.png
│   ├── X3
│   │   ├── 0001x3.png
│   │   ├─ ...
│   │   └── 0900x3.png
│   └── X4
│       ├── 0001x4.png
│       ├─ ...
│       └── 0900x4.png
└── DIV2K_train_LR_unknown
    ├── X2
    │   ├── 0001x2.png
    │   ├─ ...
    │   └── 0900x2.png
    ├── X3
    │   ├── 0001x3.png
    │   ├─ ...
    │   └── 0900x3.png
    └── X4
        ├── 0001x4.png
        ├─ ...
        └── 0900x4.png
```
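
To make the note above concrete, the sketch below mirrors how train.py builds a MindSpore dataset on top of this directory layout, using the DIV2K wrapper from src/data/div2k.py. The path is a placeholder and the scale/batch-size values are only the defaults used by the training scripts; treat it as an orientation sketch rather than an additional entry point.

```python
# Minimal sketch of the data pipeline used by train.py (assumes this repo's src/ is importable).
import mindspore.dataset as ds

from src.args import args          # command-line defaults, e.g. --dir_data, --scale
from src.data.div2k import DIV2K   # reads DIV2K_train_HR / DIV2K_train_LR_bicubic

args.dir_data = "/path/to/parent_of_DIV2K"   # placeholder: directory that contains the DIV2K folder
args.scale = [2]                             # use the X2 low-resolution inputs

train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)

train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], shuffle=True)
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
```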

# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# Quick Start

After installing MindSpore from the official website, you can start training and evaluation as follows:

```shell
# standalone training
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
```

```shell
# distributed training
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```

```shell
# evaluation
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
```
# Script Description

## Script and Sample Code

```bash
WDSR
├── README_CN.md                   // README file
├── eval.py                        // evaluation script
├── export.py                      // export script
├── script
│   ├── run_ascend_distribute.sh   // shell script for distributed training on Ascend
│   ├── run_ascend_standalone.sh   // shell script for standalone training on Ascend
│   └── run_eval.sh                // shell script for evaluation
├── src
│   ├── args.py                    // hyper-parameters
│   ├── common.py                  // common network blocks
│   ├── data
│   │   ├── common.py              // common dataset utilities
│   │   ├── div2k.py               // DIV2K dataset
│   │   └── srdata.py              // base class for all datasets
│   ├── metrics.py                 // PSNR and SSIM calculators
│   ├── model.py                   // WDSR network
│   └── utils.py                   // training wrapper (Trainer)
└── train.py                       // training script
```
## Script Parameters

The main parameters are as follows:

```python
-h, --help            show this help message and exit
--dir_data DIR_DATA   dataset directory
--data_train DATA_TRAIN
                      train dataset name
--data_test DATA_TEST
                      test dataset name
--data_range DATA_RANGE
                      train/test data range
--ext EXT             dataset file extension
--scale SCALE         super resolution scale
--patch_size PATCH_SIZE
                      output patch size
--rgb_range RGB_RANGE
                      maximum value of RGB
--n_colors N_COLORS   number of color channels to use
--no_augment          do not use data augmentation
--model MODEL         model name
--n_resblocks N_RESBLOCKS
                      number of residual blocks
--n_feats N_FEATS     number of feature maps
--res_scale RES_SCALE
                      residual scaling
--test_every TEST_EVERY
                      do test per every N batches
--epochs EPOCHS       number of epochs to train
--batch_size BATCH_SIZE
                      input batch size for training
--test_only           set this option to test the model
--lr LR               learning rate
--ckpt_save_path CKPT_SAVE_PATH
                      path to save ckpt
--ckpt_save_interval CKPT_SAVE_INTERVAL
                      save ckpt frequency, unit is epoch
--ckpt_save_max CKPT_SAVE_MAX
                      max number of saved ckpt
--ckpt_path CKPT_PATH
                      path of saved ckpt
--task_id TASK_ID
```
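
One detail to keep in mind when setting these flags: after parsing, src/args.py converts --scale and the dataset names into lists (splitting on "+"), which is why the rest of the code indexes args.scale[0] and args.data_test[0]. The standalone snippet below only mirrors that post-processing on example values.

```python
# Mirrors the post-processing at the bottom of src/args.py (standalone illustration, example values).
scale = "2+3+4"        # value that would be passed as --scale
data_test = "DIV2K"    # value that would be passed as --data_test

scale = [int(x) for x in scale.split("+")]   # -> [2, 3, 4]
data_test = data_test.split("+")             # -> ['DIV2K']

print(scale, data_test)
# eval.py and train.py then use scale[0] and data_test[0]
```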

## Training Process

### Training

- Running on Ascend

```bash
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
```

The command above runs in the background; you can view the results in the train.log file.

### Distributed Training

- Running on Ascend

```bash
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```

For example, TRAIN_DATA_DIR = "~DATA/".
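
For reference, train.py (and the Trainer in src/utils.py) decays the learning rate by halving it every 200 epochs, starting from --lr. The small sketch below reproduces that schedule; the epoch values are only examples.

```python
# Learning-rate schedule used by train.py: lr / 2 ** ((epoch + 1) // 200).
def wdsr_lr(base_lr, epoch):
    """Learning rate applied during the given zero-based epoch."""
    return base_lr / (2 ** ((epoch + 1) // 200))

base_lr = 1e-4  # the default --lr
for epoch in (0, 199, 200, 399, 400, 999):
    print(epoch, wdsr_lr(base_lr, epoch))
# 0 -> 1e-4, 199/200 -> 5e-5, 399/400 -> 2.5e-5, 999 -> 3.125e-6
```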

## Evaluation Process

### Evaluation

Before running the command below, check the checkpoint path used for evaluation.

```bash
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] DIV2K
```

For example, TEST_DATA_DIR = "~DATA/".

You can view the results in the eval.log file.
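
eval.py reports PSNR and SSIM per image by quantizing the network output and calling the helpers in src/metrics.py. The sketch below isolates that metric step on random arrays (it assumes the repository's src package is importable); with real data, sr would be the network output and hr the ground-truth image.

```python
# Stand-alone illustration of the metric step from eval.py (random data, repo's src/ on PYTHONPATH).
import numpy as np

from src.metrics import calc_psnr, calc_ssim, quantize

scale = 2
hr = np.random.randint(0, 256, (1, 3, 96, 96)).astype(np.float32)   # ground-truth HR, NCHW
sr = hr + np.random.normal(0, 5, hr.shape).astype(np.float32)       # stand-in for the network output

sr = quantize(sr, 255)                 # clip and round to the 0-255 range
psnr = calc_psnr(sr, hr, scale, 255.0)

# calc_ssim expects HWC images, so drop the batch dimension and transpose, as eval.py does.
sr_hwc = sr.reshape(sr.shape[-3:]).transpose(1, 2, 0)
hr_hwc = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
ssim = calc_ssim(sr_hwc, hr_hwc, scale)

print("PSNR: %.4f, SSIM: %.4f" % (psnr, ssim))
```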

# Model Export

```bash
python export.py --ckpt_path [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```

FILE_FORMAT can be chosen from ['MINDIR', 'AIR', 'ONNX']; the default is 'MINDIR'.
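
If you export to MINDIR, the file can be loaded back for a quick inference check with MindSpore's load/GraphCell API, depending on your MindSpore version. The snippet below is only a hedged example; the file name and the fixed input shape follow export.py's defaults.

```python
# Minimal check of an exported MINDIR file (assumes export.py was run with its defaults).
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

graph = ms.load("wdsr.mindir")                 # default --file_name is "wdsr"
net = nn.GraphCell(graph)
# export.py fixes the input shape to (batch_size, 3, 678, 1020)
x = Tensor(np.zeros([1, 3, 678, 1020], np.float32))
y = net(x)
print(y.shape)   # expected (1, 3, 1356, 2040) for the x2 model
```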

# Model Description

## Performance

### Training Performance

| Parameters              | Ascend                                            |
| ----------------------- | ------------------------------------------------- |
| Resource                | Ascend 910                                        |
| Upload date             | 2021-7-4                                          |
| MindSpore version       | 1.2.0                                             |
| Dataset                 | DIV2K                                             |
| Training parameters     | epoch=1000, steps=100, batch_size=16, lr=0.0001   |
| Optimizer               | Adam                                              |
| Loss function           | L1                                                |
| Output                  | super-resolution images                           |
| Loss                    | 3.5                                               |
| Speed                   | 8 devices: about 130 ms/step                      |
| Total time              | 8 devices: 0.5 hours                              |
| Checkpoint for fine-tuning | 35 MB (.ckpt file)                             |
| Script                  | [WDSR](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/wdsr) |

### Evaluation Performance

| Parameters        | Ascend                  |
| ----------------- | ----------------------- |
| Resource          | Ascend 910              |
| Upload date       | 2021-7-4                |
| MindSpore version | 1.2.0                   |
| Dataset           | DIV2K                   |
| batch_size        | 1                       |
| Output            | super-resolution images |
| PSNR              | DIV2K 34.7780           |

# Description of Random Situation

In train.py, we set the seed inside the "train_net" function.

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py
@@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wdsr eval script"""
import os
import numpy as np
import mindspore.dataset as ds
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
import src.model as wdsr
from src.data.srdata import SRData
from src.data.div2k import DIV2K
from src.metrics import calc_psnr, quantize, calc_ssim

device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(max_call_depth=10000)

def eval_net():
    """eval"""
    if args.epochs == 0:
        args.epochs = 1e8
    for arg in vars(args):
        if vars(args)[arg] == 'True':
            vars(args)[arg] = True
        elif vars(args)[arg] == 'False':
            vars(args)[arg] = False
    if args.data_test[0] == 'DIV2K':
        train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
    else:
        train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
    train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
    train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
    net_m = wdsr.WDSR()
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(net_m, param_dict)
    net_m.set_train(False)
    print('load mindspore net successfully.')
    num_imgs = train_de_dataset.get_dataset_size()
    psnrs = np.zeros((num_imgs, 1))
    ssims = np.zeros((num_imgs, 1))
    for batch_idx, imgs in enumerate(train_loader):
        lr = imgs['LR']
        hr = imgs['HR']
        lr = Tensor(lr, mstype.float32)
        pred = net_m(lr)
        pred_np = pred.asnumpy()
        pred_np = quantize(pred_np, 255)
        psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
        pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
        hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
        ssim = calc_ssim(pred_np, hr, args.scale[0])
        print("current psnr: ", psnr)
        print("current ssim: ", ssim)
        psnrs[batch_idx, 0] = psnr
        ssims[batch_idx, 0] = ssim
    print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
    print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))

if __name__ == '__main__':
    print("Start eval function!")
    eval_net()

export.py
@@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export net together with checkpoint into air/mindir/onnx models"""
import os
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
import src.model as wdsr

parser = argparse.ArgumentParser(description='wdsr export')
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
parser.add_argument("--file_name", type=str, default="wdsr", help="output file name.")
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
args = parser.parse_args()

def run_export(args1):
    """run_export"""
    device_id = int(os.getenv("DEVICE_ID", '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    net = wdsr.WDSR()
    param_dict = load_checkpoint(args1.ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    print('load mindspore net and checkpoint successfully.')
    # The exported graph uses a fixed input shape of (batch_size, 3, 678, 1020).
    inputs = Tensor(np.zeros([args1.batch_size, 3, 678, 1020], np.float32))
    export(net, inputs, file_name=args1.file_name, file_format=args1.file_format)
    print('export successfully!')

if __name__ == "__main__":
    run_export(args)

script/run_ascend_distribute.sh
@@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]; then
    echo "Usage: sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -f $PATH1 ]; then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d $PATH2 ]; then
    echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
    exit 1
fi

export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env >env.log

    nohup python train.py \
        --batch_size 16 \
        --lr 1e-4 \
        --scale 2 \
        --task_id 0 \
        --dir_data $PATH2 \
        --epochs 1000 \
        --test_every 100 \
        --patch_size 48 > train.log 2>&1 &
    cd ..
done

script/run_ascend_standalone.sh
@@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 1 ]; then
    echo "Usage: sh run_ascend_standalone.sh [TRAIN_DATA_DIR]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)

if [ ! -d $PATH1 ]; then
    echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
    exit 1
fi

if [ -d "train" ]; then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit

env >env.log

nohup python train.py \
    --batch_size 16 \
    --lr 1e-4 \
    --scale 2 \
    --task_id 0 \
    --dir_data $PATH1 \
    --epochs 1000 \
    --test_every 100 \
    --patch_size 48 > train.log 2>&1 &

script/run_eval.sh
@@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
DATASET_TYPE=$3

if [ ! -d $PATH1 ]; then
    echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

if [ -d "eval" ]; then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation ..."

python eval.py \
    --dir_data=${PATH1} \
    --test_only \
    --ext "img" \
    --data_test=${DATASET_TYPE} \
    --ckpt_path=${PATH2} \
    --task_id 0 \
    --scale 2 > eval.log 2>&1 &

src/args.py
@@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""args parser"""
import argparse

parser = argparse.ArgumentParser(description='WDSR')
# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
                    help='dataset directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-900',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='WDSR',
                    help='model name')
parser.add_argument('--n_resblocks', type=int, default=16,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')
# Training specifications
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=300,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training')
#parser.add_argument('--self_ensemble', action='store_true',
#                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--init_loss_scale', type=float, default=65536.,
                    help='scaling factor')
parser.add_argument('--decay', type=str, default='200',
                    help='learning rate decay type')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                    help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
                    help='gradient clipping threshold (0 = no clipping)')
# ckpt specifications
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
                    help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
                    help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=5,
                    help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
                    help='path of saved ckpt')
# alltask
parser.add_argument('--task_id', type=int, default=0)
# rgb_mean
parser.add_argument('--r_mean', type=float, default=0.4488,
                    help='r_mean')
parser.add_argument('--g_mean', type=float, default=0.4371,
                    help='g_mean')
parser.add_argument('--b_mean', type=float, default=0.4040,
                    help='b_mean')

args, unparsed = parser.parse_known_args()
# normalize arguments: "--scale 2+3+4" becomes [2, 3, 4]; dataset names become lists
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
    args.epochs = 1e4
for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False

src/data/common.py
@@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import os
import random
import numpy as np

def get_patch(*args, patch_size=96, scale=2, input_large=False):
    """common"""
    ih, iw = args[0].shape[:2]
    tp = patch_size
    ip = tp // scale
    # print("tp=%g,scale=%g,"%(tp,scale),end='',flush=True)
    # print("ih=%g,iw=%g,ip=%g"%(ih,iw,ip),flush=True)
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    # ix = iy = 0
    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy
    ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
    return ret

def set_channel(*args, n_channels=3):
    """common"""
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)
        c = img.shape[2]
        if n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)
        return img[:, :, :n_channels]
    return [_set_channel(a) for a in args]

def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        input_data = np_transpose.astype(np.float32)
        output = input_data * (rgb_range / 255)
        return output
    return [_np2Tensor(a) for a in args]

def augment(*args, hflip=True, rot=True):
    """common"""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5
    def _augment(img):
        """common"""
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img
    return [_augment(a) for a in args]

def search(root, target="JPEG"):
    """srdata"""
    item_list = []
    items = os.listdir(root)
    for item in items:
        path = os.path.join(root, item)
        if os.path.isdir(path):
            item_list.extend(search(path, target))
        elif path.split('/')[-1].startswith(target):
            item_list.append(path)
        elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
            item_list.append(path)
    return item_list

src/data/div2k.py
@@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""div2k"""
import os
from src.data.srdata import SRData

class DIV2K(SRData):
    """DIV2K"""
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]
        self.begin, self.end = list(map(int, data_range))
        super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
        self.dir_hr = None
        self.dir_lr = None

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')

src/data/srdata.py
@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""srdata"""
import os
import glob
import random
import pickle
import imageio
from src.data import common
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class SRData:
    """srdata"""
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0
        self._set_filesystem(args.dir_data)
        self._set_img(args)
        if train:
            self._repeat(args)

    def _set_img(self, args):
        """srdata"""
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)
        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or self.benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
            for s in self.scale:
                if s == 1:
                    os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
                else:
                    os.makedirs(
                        os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

    def _repeat(self, args):
        """srdata"""
        n_patches = args.batch_size * args.test_every
        n_images = len(args.data_train) * len(self.images_hr)
        if n_images == 0:
            self.repeat = 0
        else:
            self.repeat = max(n_patches // n_images, 1)

    def _scan(self):
        """srdata"""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                if s != 1:
                    scale = s
                    names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
                        .format(s, filename, scale, self.ext[1])))
        for si, s in enumerate(self.scale):
            if s == 1:
                names_lr[si] = names_hr
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name[0])
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    # pylint: disable=unused-variable
    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
        #return pair_t[0], pair_t[1], [self.idx_scale], [filename]
        return pair_t[0], pair_t[1]

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        return idx

    def _load_file_deblur(self, idx, train=True):
        """srdata"""
        idx = self._get_index(idx)
        if train:
            f_hr = self.images_hr[idx]
            f_lr = self.images_lr[idx]
        else:
            f_hr = self.deblur_hr_test[idx]
            f_lr = self.deblur_lr_test[idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        filename = f_hr[-27:-17] + filename
        hr = imageio.imread(f_hr)
        lr = imageio.imread(f_lr)
        return lr, hr, filename

    def _load_file_hr(self, idx):
        """srdata"""
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
        return hr, filename

    def _load_rain_test(self, idx):
        f_hr = self.derain_hr_test[idx]
        f_lr = self.derain_lr_test[idx]
        filename, _ = os.path.splitext(os.path.basename(f_lr))
        norain = imageio.imread(f_hr)
        rain = imageio.imread(f_lr)
        return norain, rain, filename

    def _load_file(self, idx):
        """srdata"""
        idx = self._get_index(idx)
        # print(idx,flush=True)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)
        return lr, hr, filename

    def get_patch_hr(self, hr):
        """srdata"""
        if self.train:
            hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
        return hr

    def get_patch_img_hr(self, img, patch_size=96, scale=2):
        """srdata"""
        ih, iw = img.shape[:2]
        tp = patch_size
        ip = tp // scale
        ix = random.randrange(0, iw - ip + 1)
        iy = random.randrange(0, ih - ip + 1)
        ret = img[iy:iy + ip, ix:ix + ip, :]
        return ret

    def get_patch(self, lr, hr):
        """srdata"""
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size * scale,
                scale=scale)
            if not self.args.no_augment:
                lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]
        return lr, hr

    def set_scale(self, idx_scale):
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)

src/metrics.py
@@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""
import math
import numpy as np
import cv2


def quantize(img, rgb_range):
    """quantize image range to 0-255"""
    pixel_range = 255 / rgb_range
    img = np.multiply(img, pixel_range)
    img = np.clip(img, 0, 255)
    img = np.round(img) / pixel_range
    return img


def calc_psnr(sr, hr, scale, rgb_range):
    """calculate psnr"""
    hr = np.float32(hr)
    sr = np.float32(sr)
    diff = (sr - hr) / rgb_range
    gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
    diff = np.multiply(diff, gray_coeffs).sum(1)
    if hr.size == 1:
        return 0
    if scale != 1:
        shave = scale
    else:
        shave = scale + 6
    if scale == 1:
        valid = diff
    else:
        valid = diff[..., shave:-shave, shave:-shave]
    mse = np.mean(pow(valid, 2))
    return -10 * math.log10(mse)


def rgb2ycbcr(img, y_only=True):
    """from rgb space to ycbcr space"""
    img.astype(np.float32)
    if y_only:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    return rlt


def calc_ssim(img1, img2, scale):
    """calculate ssim value"""
    def ssim(img1, img2):
        C1 = (0.01 * 255) ** 2
        C2 = (0.03 * 255) ** 2
        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        kernel = cv2.getGaussianKernel(11, 1.5)
        window = np.outer(kernel, kernel.transpose())
        mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
        mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
        sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
        sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()

    border = 0
    if scale != 1:
        border = scale
    else:
        border = scale + 6
    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]
    if img1_y.ndim == 2:
        return ssim(img1_y, img2_y)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for _ in range(3):
                ssims.append(ssim(img1, img2))
            return np.array(ssims).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')

src/model.py
@@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main model of wdsr"""

import mindspore
import mindspore.nn as nn
import numpy as np


class MeanShift(mindspore.nn.Conv2d):
    """add or sub means of input data"""
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, dtype=mindspore.float32):
        std = mindspore.Tensor(rgb_std, dtype)
        weight = mindspore.Tensor(np.eye(3), dtype).reshape(3, 3, 1, 1) / std.reshape(3, 1, 1, 1)
        bias = sign * rgb_range * mindspore.Tensor(rgb_mean, dtype) / std
        super(MeanShift, self).__init__(3, 3, kernel_size=1, has_bias=True, weight_init=weight, bias_init=bias)
        for p in self.get_parameters():
            p.requires_grad = False


class Block(nn.Cell):
    """residual block"""
    def __init__(self):
        super(Block, self).__init__()
        # wn = lambda x: mindspore.nn.GroupNorm(x)
        act = nn.ReLU()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.8
        # wide activation: 1x1 expansion -> ReLU -> 1x1 linear low-rank -> 3x3 convolution
        body.append(nn.Conv2d(64, 64 * expand, 1, padding=(1 // 2), pad_mode='pad', has_bias=True))
        body.append(act)
        body.append(nn.Conv2d(64 * expand, int(64 * linear), 1, padding=1 // 2, pad_mode='pad', has_bias=True))
        body.append(nn.Conv2d(int(64 * linear), 64, 3, padding=3 // 2, pad_mode='pad', has_bias=True))
        self.body = nn.SequentialCell(body)

    def construct(self, x):
        res = self.body(x)
        res += x
        return res


class PixelShuffle(nn.Cell):
    """perform pixel shuffle"""
    def __init__(self, upscale_factor):
        super().__init__()
        self.DepthToSpace = mindspore.ops.DepthToSpace(upscale_factor)

    def construct(self, x):
        return self.DepthToSpace(x)


class WDSR(nn.Cell):
    """main structure of wdsr"""
    def __init__(self):
        super(WDSR, self).__init__()
        scale = 2
        n_resblocks = 8
        n_feats = 64
        self.sub_mean = MeanShift(255)
        self.add_mean = MeanShift(255, sign=1)
        # define head module
        head = []
        head.append(
            nn.Conv2d(
                3, n_feats, 3,
                pad_mode='pad',
                padding=(3 // 2), has_bias=True))

        # define body module
        body = []
        for _ in range(n_resblocks):
            body.append(Block())
        # define tail module
        tail = []
        out_feats = scale * scale * 3
        tail.append(
            nn.Conv2d(n_feats, out_feats, 3, padding=3 // 2, pad_mode='pad', has_bias=True))
        self.depth_to_space = mindspore.ops.DepthToSpace(2)
        skip = []
        skip.append(
            nn.Conv2d(3, out_feats, 5, padding=5 // 2, pad_mode='pad', has_bias=True))
        self.head = nn.SequentialCell(head)
        self.body = nn.SequentialCell(body)
        self.tail = nn.SequentialCell(tail)
        self.skip = nn.SequentialCell(skip)

    def construct(self, x):
        x = self.sub_mean(x) / 127.5
        s = self.skip(x)
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        x += s
        x = self.depth_to_space(x)
        x = self.add_mean(x * 127.5)
        return x

src/utils.py
@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wdsr train wrapper"""
import os
import time
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint

class Trainer():
    """Trainer"""
    def __init__(self, args, loader, my_model):
        self.args = args
        self.scale = args.scale
        self.trainloader = loader
        self.model = my_model
        self.model.set_train()
        self.criterion = nn.L1Loss()
        self.loss_history = []
        self.begin_time = time.time()
        self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
        self.loss_net = nn.WithLossCell(self.model, self.criterion)
        self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer)

    def train(self, epoch):
        """Trainer"""
        losses = 0
        batch_idx = 0
        for batch_idx, imgs in enumerate(self.trainloader):
            lr = imgs["LR"]
            hr = imgs["HR"]
            lr = Tensor(lr, mstype.float32)
            hr = Tensor(hr, mstype.float32)
            t1 = time.time()
            loss = self.net(lr, hr)
            t2 = time.time()
            losses += loss.asnumpy()
            print('Epoch: %g, Step: %g , loss: %f, time: %f s ' % \
                (epoch, batch_idx, loss.asnumpy(), t2 - t1), end='\n', flush=True)
        print("the epoch loss is", losses / (batch_idx + 1), flush=True)
        self.loss_history.append(losses / (batch_idx + 1))
        print(self.loss_history)
        t = time.time() - self.begin_time
        t = int(t)
        print(", running time: %gh%g'%g''" % (t // 3600, (t - t // 3600 * 3600) // 60, t % 60), flush=True)
        os.makedirs(self.args.save, exist_ok=True)
        if self.args.rank == 0 and (epoch + 1) % 10 == 0:
            save_checkpoint(self.net, self.args.save + "model_" + str(self.epoch) + '.ckpt')

    def update_learning_rate(self, epoch):
        """Update learning rates for all the networks; called at the end of every epoch.

        :param epoch: current epoch
        :type epoch: int
        :param lr: learning rate
        :type lr: float
        :param niter: number of epochs with the initial learning rate
        :type niter: int
        :param niter_decay: number of epochs to linearly decay learning rate to zero
        :type niter_decay: int
        """
        self.epoch = epoch
        print("*********** epoch: {} **********".format(epoch))
        lr = self.args.lr / (2 ** ((epoch + 1) // 200))
        self.adjust_lr('model', self.optimizer, lr)
        print("*********************************")

    def adjust_lr(self, name, optimizer, lr):
        """Adjust learning rate for the corresponding model.

        :param name: name of model
        :type name: str
        :param optimizer: the optimizer of the corresponding model
        :type optimizer: mindspore.nn.Optimizer
        :param lr: learning rate to be adjusted
        :type lr: float
        """
        lr_param = optimizer.get_lr()
        lr_param.assign_value(Tensor(lr, mstype.float32))
        print('==> ' + name + ' learning rate: ', lr_param.asnumpy())

train.py
@@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wdsr train script"""
import os
from mindspore import context
from mindspore import dataset as ds
import mindspore.nn as nn
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.args import args
from src.data.div2k import DIV2K
from src.model import WDSR


def train_net():
    """train wdsr"""
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    # if distribute:
    if device_num > 1:
        init()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          device_num=device_num, gradients_mean=True)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
    train_dataset.set_scale(args.task_id)
    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
                                           shard_id=rank_id, shuffle=True)
    train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
    net_m = WDSR()
    print("Init net successfully")
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(net_m, param_dict)
        print("Load net weight successfully")
    step_size = train_de_dataset.get_dataset_size()
    lr = []
    for i in range(0, args.epochs):
        # halve the learning rate every 200 epochs
        cur_lr = args.lr / (2 ** ((i + 1) // 200))
        lr.extend([cur_lr] * step_size)
    opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=1024.0)
    loss = nn.L1Loss()
    loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \
                                                 scale_factor=2, scale_window=1000)
    model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
                                 keep_checkpoint_max=args.ckpt_save_max)
    ckpt_cb = ModelCheckpoint(prefix="wdsr", directory=args.ckpt_save_path, config=config_ck)
    if device_id == 0:
        cb += [ckpt_cb]
    model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)


if __name__ == "__main__":
    train_net()