!18991 add rcan to modelzoo

Merge pull request !18991 from 罗柄淳/master
i-robot 2021-07-13 11:41:53 +00:00 committed by Gitee
commit 9f8502598f
16 changed files with 1519 additions and 0 deletions

@ -0,0 +1,293 @@
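# Contents
<!-- TOC -->
- [Contents](#contents)
- [RCAN Description](#rcan-description)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
        - [Training](#training)
        - [Evaluation](#evaluation)
    - [Parameter Configuration](#parameter-configuration)
    - [Training Process](#training-process)
        - [Training](#training-1)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation-1)
    - [Model Export](#model-export)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->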
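# RCAN Description
Depth is critical for convolutional neural networks (CNNs) in image super-resolution (SR). However, we observe that deeper networks for image SR are harder to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across channels and therefore hinders the representational ability of CNNs. To address these problems, we propose the very deep residual channel attention network (RCAN). Specifically, we propose a residual-in-residual (RIR) structure to form a very deep network, which consists of several residual groups with long skip connections. Each residual group contains several residual blocks with short skip connections. Meanwhile, RIR lets abundant low-frequency information bypass the network through multiple skip connections, so the main network can focus on learning high-frequency information. In addition, we propose a channel attention mechanism that adaptively rescales channel-wise features by considering the interdependencies among channels. Extensive experiments show that RCAN achieves better accuracy and visual quality than existing methods.
![CA](https://gitee.com/bcc2974874275/mindspore/raw/master/model_zoo/research/cv/RCAN/Figs/CA.PNG)
Channel attention (CA) structure.
![RCAB](https://gitee.com/bcc2974874275/mindspore/raw/master/model_zoo/research/cv/RCAN/Figs/RCAB.PNG)
Residual channel attention block (RCAB) structure.
![RCAN](https://gitee.com/bcc2974874275/mindspore/raw/master/model_zoo/research/cv/RCAN/Figs/RCAN.PNG)
Architecture of the proposed residual channel attention network (RCAN).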
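To make the channel attention idea concrete, here is a minimal pure-Python sketch of the rescaling step: global average pooling per channel, a two-layer bottleneck with ReLU and sigmoid, then channel-wise multiplication. The function name `channel_attention`, the toy weights, and the omission of biases are assumptions for illustration only; the actual network in `src/rcan_model.py` uses learned convolution weights and is not reproduced here.

```python
import math


def channel_attention(features, w_down, w_up):
    """Rescale each channel of `features` by a gate in (0, 1).

    features: list of C channels, each an H x W list of floats.
    w_down:   (C // r) x C weight matrix for channel reduction (hypothetical values).
    w_up:     C x (C // r) weight matrix for channel expansion (hypothetical values).
    """
    # Squeeze: global average pooling gives one descriptor per channel.
    pooled = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in features]
    # Excitation, step 1: reduce the channel dimension, then apply ReLU.
    hidden = [max(0.0, sum(w * p for w, p in zip(row, pooled))) for row in w_down]
    # Excitation, step 2: expand back to C channels, then apply a sigmoid.
    gates = [1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(row, hidden)))) for row in w_up]
    # Scale: multiply every pixel of a channel by that channel's gate.
    scaled = [[[v * g for v in row] for row in ch] for ch, g in zip(features, gates)]
    return scaled, gates


# Toy input: 2 channels of 2x2 pixels, reduced to 1 hidden unit.
feats = [[[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
out, gates = channel_attention(feats, w_down=[[0.5, 0.5]], w_up=[[1.0], [-1.0]])
```

With these toy weights the first channel receives a larger gate than the second, so it is attenuated less; in a trained network the gates depend on the learned weights and the input statistics.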
# Dataset
## Dataset used: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
- Dataset size: about 7.12 GB, 900 images in total
    - Training set: 800 images
    - Test set: 100 images
- Benchmark datasets can be downloaded here: [Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html), [Set14](https://deepai.org/dataset/set14-super-resolution), [B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), [Urban100](http://vllab.ucmerced.edu/wlai24/LapSRN/).
- Data format: PNG files
    - Note: the data is processed in src/data/div2k.py.
```bash
DIV2K
├── DIV2K_test_LR_bicubic
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├─ ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├─ ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├─ ...
│       └── 1000x4.png
├── DIV2K_test_LR_unknown
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├─ ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├─ ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├─ ...
│       └── 1000x4.png
├── DIV2K_train_HR
│   ├── 0001.png
│   ├─ ...
│   └── 0900.png
├── DIV2K_train_LR_bicubic
│   ├── X2
│   │   ├── 0001x2.png
│   │   ├─ ...
│   │   └── 0900x2.png
│   ├── X3
│   │   ├── 0001x3.png
│   │   ├─ ...
│   │   └── 0900x3.png
│   └── X4
│       ├── 0001x4.png
│       ├─ ...
│       └── 0900x4.png
└── DIV2K_train_LR_unknown
    ├── X2
    │   ├── 0001x2.png
    │   ├─ ...
    │   └── 0900x2.png
    ├── X3
    │   ├── 0001x3.png
    │   ├─ ...
    │   └── 0900x3.png
    └── X4
        ├── 0001x4.png
        ├─ ...
        └── 0900x4.png
```
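Given the layout above, the low-resolution counterpart of a high-resolution image can be located from its file name and the scale. The sketch below mirrors the `X{scale}/{name}x{scale}.png` pattern that `_scan` in `src/data/srdata.py` builds; the helper name `lr_path_for` and the example root `/data/DIV2K` are hypothetical.

```python
import posixpath


def lr_path_for(hr_path, lr_dir, scale, ext=".png"):
    """Return the LR file path matching an HR file, e.g. 0001.png -> X2/0001x2.png."""
    stem, _ = posixpath.splitext(posixpath.basename(hr_path))
    return posixpath.join(lr_dir, "X{}".format(scale), "{}x{}{}".format(stem, scale, ext))


example = lr_path_for(
    "/data/DIV2K/DIV2K_train_HR/0001.png",
    "/data/DIV2K/DIV2K_train_LR_bicubic",
    scale=2,
)
print(example)  # /data/DIV2K/DIV2K_train_LR_bicubic/X2/0001x2.png
```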
# Environment Requirements
- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
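# Script Description
## Scripts and Sample Code
```bash
├── model_zoo
    ├── README.md                          // description of all models
    ├── RCAN
        ├── scripts
        │   ├── run_distribute_train.sh    // shell script for distributed training on Ascend
        │   ├── run_eval.sh                // shell script for evaluation
        │   ├── run_ascend_standalone.sh   // shell script for single-device training on Ascend
        ├── src
        │   ├── data
        │   │   ├──common.py               // common data utilities
        │   │   ├──div2k.py                // DIV2K dataset
        │   │   ├──srdata.py               // base class for all datasets
        │   ├── rcan_model.py              // RCAN network
        │   ├── metrics.py                 // PSNR and SSIM calculators
        │   ├── args.py                    // hyperparameters
        ├── train.py                       // training script
        ├── eval.py                        // evaluation script
        ├── export.py                      // model export
        ├── README.md                      // this README
```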
## Script Parameters
### Training
```bash
Usage: python train.py [--device_target][--dir_data]
                       [--ckpt_path][--test_every][--task_id]
Options:
    --device_target       Training backend type (Ascend). Default: Ascend.
    --dir_data            Path where the dataset is stored.
    --ckpt_path           Path where checkpoints are stored.
    --test_every          Run a test every N batches.
    --task_id             Task ID.
```
### Evaluation
```bash
Usage: python eval.py [--device_target][--dir_data]
                      [--task_id][--scale][--data_test]
                      [--ckpt_path]
Options:
    --device_target       Evaluation backend type (Ascend).
    --dir_data            Dataset path.
    --task_id             Task ID.
    --scale               Super-resolution scale factor.
    --data_test           Name of the test dataset.
    --ckpt_path           Checkpoint path.
```
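`--scale` and `--data_test` accept several values joined by `+`. `src/args.py` splits them into lists after parsing, and `eval.py` then uses the first entry of each. A minimal sketch of that splitting, with a hypothetical helper name `split_plus_options`:

```python
def split_plus_options(scale, data_test):
    """Split '+'-joined option strings the way src/args.py does after parsing."""
    scales = [int(x) for x in scale.split("+")]
    datasets = data_test.split("+")
    return scales, datasets


scales, datasets = split_plus_options("2+3+4", "Set5")
print(scales[0], datasets[0])  # 2 Set5
```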
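## Parameter Configuration
Both training and evaluation parameters can be configured in args.py.
- RCAN configuration for the DIV2K dataset
```bash
"lr": 0.0001,                        # learning rate
"epochs": 500,                       # number of training epochs
"batch_size": 16,                    # batch size of the input tensor
"weight_decay": 0,                   # weight decay
"loss_scale": 1024,                  # loss scale
"buffer_size": 10,                   # shuffle buffer size
"init_loss_scale":65536,             # initial loss scaling factor
"betas":(0.9, 0.999),                # Adam betas
"num_layers":4,                      # number of layers
"test_every":4000,                   # run a test every N batches
"n_resgroups":10,                    # number of residual groups
"reduction":16,                      # feature-map reduction ratio
"patch_size":48,                     # output patch size
"scale":'2',                         # super-resolution scale
"task_id":0,                         # task ID
"n_colors":3,                        # number of color channels
"n_resblocks":20,                    # number of residual blocks
"n_feats":64,                        # number of feature maps
"res_scale":1,                       # residual scaling
```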
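## Training Process
### Training
#### Running RCAN on Ascend
- Single-device training (1p)
    - 2x super-resolution: task_id 0
    - 3x super-resolution: task_id 1
    - 4x super-resolution: task_id 2
```bash
sh scripts/run_ascend_standalone.sh [TRAIN_DATA_DIR]
```
- Distributed training
    - 2x super-resolution: task_id 0
    - 3x super-resolution: task_id 1
    - 4x super-resolution: task_id 2
```bash
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
- Distributed training requires an HCCL configuration file in JSON format, created in advance. For details, see <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>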
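Both launch scripts in this commit pass `--task_id 0`, so as written they start the 2x task. The sketch below builds the same `train.py` argument list for another task. `build_train_command` and the example data path are hypothetical; the flag names and values are copied from the scripts, and how `train.py` consumes `--task_id` is not shown here.

```python
import shlex

SCALE_FOR_TASK = {0: 2, 1: 3, 2: 4}  # mapping stated in this README


def build_train_command(task_id, data_dir):
    """Return the train.py argument list used by the launch scripts, for a given task."""
    if task_id not in SCALE_FOR_TASK:
        raise ValueError("task_id must be 0, 1 or 2, got {!r}".format(task_id))
    return [
        "python", "train.py",
        "--batch_size", "16",
        "--lr", "1e-4",
        "--scale", "2+3+4",
        "--task_id", str(task_id),
        "--dir_data", data_dir,
        "--epochs", "500",
        "--test_every", "4000",
        "--patch_size", "48",
    ]


# Hypothetical data directory; print a shell-safe version of the command.
cmd = build_train_command(1, "/path/to/data")
print(" ".join(shlex.quote(part) for part in cmd))
```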
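## Evaluation Process
### Evaluation
- Evaluation proceeds as follows. Set the dataset type to "Set5" or "B100".
```bash
sh scripts/run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
```
- The python command above runs in the background; check the results in `eval.log`.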
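`eval.py` prints per-image values followed by two summary lines of the form `Mean psnr of <dataset> x<scale> is <value>` and `Mean ssim of ...`. The following sketch pulls those two summary values out of the log text. The function name is illustrative and the sample numbers are made up to show the format; they are not measured results.

```python
import re

MEAN_LINE = re.compile(r"Mean (psnr|ssim) of (\S+) x(\d+) is ([0-9.]+)")


def parse_eval_summary(log_text):
    """Return a dict such as {'psnr': ..., 'ssim': ...} from the summary lines in a log."""
    summary = {}
    for metric, _dataset, _scale, value in MEAN_LINE.findall(log_text):
        summary[metric] = float(value)
    return summary


# Made-up sample in the format printed by eval.py (not real results).
sample = "current psnr:  30.0\nMean psnr of Set5 x2 is 30.0000\nMean ssim of Set5 x2 is 0.9000\n"
print(parse_eval_summary(sample))
```

For a real run, read the file first, for example `parse_eval_summary(open("eval.log").read())`.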
## Model Export
```bash
Usage: python export.py [--batch_size] [--ckpt_path] [--file_name] [--file_format]
Options:
    --batch_size      Batch size of the input tensor. Default: 1.
    --ckpt_path       Path to the checkpoint file (required).
    --file_name       Output file name. Default: rcan.
    --file_format     One of ['MINDIR', 'AIR', 'ONNX']. Default: 'MINDIR'.
```
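`export.py` traces the network with a fixed zero tensor of shape `[batch_size, 3, 678, 1020]`, so the exported model is tied to that input size. The sketch below works out the matching output shape, assuming the network enlarges height and width by the scale factor and keeps the channel count; the helper name `exported_shapes` is hypothetical.

```python
def exported_shapes(batch_size=1, scale=2, lr_height=678, lr_width=1020, channels=3):
    """Return (input_shape, output_shape) for the fixed-size export tensor."""
    input_shape = (batch_size, channels, lr_height, lr_width)
    output_shape = (batch_size, channels, lr_height * scale, lr_width * scale)
    return input_shape, output_shape


print(exported_shapes(scale=2))  # ((1, 3, 678, 1020), (1, 3, 1356, 2040))
```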
# Model Description
## Performance
### Training Performance
| Parameter                  | RCAN (Ascend)                                  |
| -------------------------- | ---------------------------------------------- |
| Model version              | RCAN                                           |
| Resource                   | Ascend 910                                     |
| Upload date                | 2021-06-30                                     |
| MindSpore version          | 1.2.0                                          |
| Dataset                    | DIV2K                                          |
| Training parameters        | epoch=500, batch_size=16, lr=0.0001            |
| Optimizer                  | Adam                                           |
| Loss function              | L1 loss                                        |
| Output                     | Super-resolved images                          |
| Loss                       |                                                |
| Speed                      | 8 devices: 205 ms/step                         |
| Total time                 | 8 devices: 14.74 hours                         |
| Checkpoint for fine-tuning | 0.2 GB (.ckpt file)                            |
| Script                     | [RCAN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/RCAN) |
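The reported speed (205 ms/step) and total time (14.74 hours on 8 devices) can be cross-checked against the training settings. The sketch below assumes 800 training images, the repeat rule from `_repeat` in `src/data/srdata.py` (`repeat = batch_size * test_every // n_images`), and an even split of each epoch across 8 devices. The even split is an assumption, since `train.py` is not shown here.

```python
def estimated_training_hours(n_images=800, batch_size=16, test_every=4000,
                             epochs=500, devices=8, seconds_per_step=0.205):
    """Estimate wall-clock hours from the dataset repeat rule and the step time."""
    repeat = max(batch_size * test_every // n_images, 1)          # 80 passes over the images
    samples_per_epoch = n_images * repeat                         # 64,000 patches
    steps_per_epoch = samples_per_epoch // batch_size // devices  # 500 steps per device
    total_steps = steps_per_epoch * epochs                        # 250,000 steps
    return total_steps * seconds_per_step / 3600.0


print(round(estimated_training_hours(), 2))  # 14.24
```

The estimate lands close to the reported 14.74 hours. The gap may come from evaluation, checkpointing, or startup time, which this sketch ignores.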
### Evaluation Performance
| Parameter           | RCAN (Ascend)                                  |
| ------------------- | ---------------------------------------------- |
| Model version       | RCAN                                           |
| Resource            | Ascend 910                                     |
| Upload date         | 2021-07-11                                     |
| MindSpore version   | 1.2.0                                          |
| Dataset             | Set5, B100                                     |
| batch_size          | 1                                              |
| Output              | Super-resolved images                          |
| Accuracy (PSNR)     | Single device, Set5: 38.15 / B100: 32.28       |
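The accuracy figures above are PSNR values in dB. `calc_psnr` in `src/metrics.py` computes PSNR as `-10 * log10(mse)` on pixel differences normalized by the RGB range; it also converts to luminance and trims a border of `scale` pixels, which this sketch leaves out. The helper name `psnr_from_pixels` is hypothetical.

```python
import math


def psnr_from_pixels(sr, hr, rgb_range=255.0):
    """PSNR in dB between two equal-length pixel sequences, normalized by rgb_range."""
    if len(sr) != len(hr) or not sr:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum(((a - b) / rgb_range) ** 2 for a, b in zip(sr, hr)) / len(sr)
    if mse == 0:
        return float("inf")  # identical images
    return -10.0 * math.log10(mse)


# A uniform error of 2.55 gray levels is 1% of the range: mse = 1e-4, PSNR = 40 dB.
print(round(psnr_from_pixels([102.55, 52.55], [100.0, 50.0]), 2))  # 40.0
```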
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval script"""
import os
import time
import numpy as np
import mindspore.dataset as ds
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
import src.rcan_model as rcan
from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize, calc_ssim
from src.data.div2k import DIV2K

device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(max_call_depth=10000)


def eval_net():
    """eval"""
    if args.epochs == 0:
        args.epochs = 1e8
    for arg in vars(args):
        if vars(args)[arg] == 'True':
            vars(args)[arg] = True
        elif vars(args)[arg] == 'False':
            vars(args)[arg] = False
    if args.data_test[0] == 'DIV2K':
        train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
    else:
        train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
    train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
    train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
    net_m = rcan.RCAN(args)
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(net_m, param_dict)
    net_m.set_train(False)
    print('load mindspore net successfully.')
    num_imgs = train_de_dataset.get_dataset_size()
    psnrs = np.zeros((num_imgs, 1))
    ssims = np.zeros((num_imgs, 1))
    for batch_idx, imgs in enumerate(train_loader):
        lr = imgs['LR']
        hr = imgs['HR']
        lr = Tensor(lr, mstype.float32)
        pred = net_m(lr)
        pred_np = pred.asnumpy()
        pred_np = quantize(pred_np, 255)
        psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
        pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
        hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
        ssim = calc_ssim(pred_np, hr, args.scale[0])
        print("current psnr: ", psnr)
        print("current ssim: ", ssim)
        psnrs[batch_idx, 0] = psnr
        ssims[batch_idx, 0] = ssim
    print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
    print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))


if __name__ == '__main__':
    time_start = time.time()
    print("Start eval function!")
    eval_net()
    time_end = time.time()
    print('eval_time: %f' % (time_end - time_start))

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export net together with checkpoint into air/mindir/onnx models"""
import os
import argparse
import numpy as np
from src.args import args as arg
from src.rcan_model import RCAN
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
parser = argparse.ArgumentParser(description='rcan export')
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
parser.add_argument("--file_name", type=str, default="rcan", help="output file name.")
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
args_1 = parser.parse_args()


def run_export(args):
    """ export """
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    net = RCAN(arg)
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    print('load mindspore net and checkpoint successfully.')
    inputs = Tensor(np.zeros([args.batch_size, 3, 678, 1020], np.float32))
    export(net, inputs, file_name=args.file_name, file_format=args.file_format)
    print('export successfully!')


if __name__ == "__main__":
    run_export(args_1)

@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -f $PATH1 ]; then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d $PATH2 ]; then
    echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
    exit 1
fi

export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env >env.log
    nohup python train.py \
        --batch_size 16 \
        --lr 1e-4 \
        --scale 2+3+4 \
        --task_id 0 \
        --dir_data $PATH2 \
        --epochs 500 \
        --test_every 4000 \
        --patch_size 48 > train.log 2>&1 &
    cd ..
done

@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
    echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)

if [ ! -d $PATH1 ]; then
    echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
    exit 1
fi

if [ -d "train" ]; then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
env >env.log

nohup python train.py \
    --batch_size 16 \
    --lr 1e-4 \
    --scale 2+3+4 \
    --task_id 0 \
    --dir_data $PATH1 \
    --epochs 500 \
    --test_every 4000 \
    --patch_size 48 > train.log 2>&1 &

@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
    echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
DATASET_TYPE=$3

if [ ! -d $PATH1 ]; then
    echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

if [ -d "eval" ]; then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation ..."

python eval.py \
    --dir_data=${PATH1} \
    --batch_size 1 \
    --test_only \
    --ext "img" \
    --data_test=${DATASET_TYPE} \
    --ckpt_path=${PATH2} \
    --task_id 0 \
    --scale 2 > eval.log 2>&1 &

@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""args"""
import argparse
import ast

parser = argparse.ArgumentParser(description='RCAN')

# Hardware specifications
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')

# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
                    help='dataset directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Model specifications
parser.add_argument('--model', default='RCAN',
                    help='model name')
parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
parser.add_argument('--n_resblocks', type=int, default=20,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')

# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
                    help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
                    help='number of feature maps reduction')

# Training specifications
parser.add_argument('--test_every', type=int, default=4000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')

# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-5,
                    help='learning rate')
parser.add_argument('--init_loss_scale', type=float, default=65536.,
                    help='scaling factor')
parser.add_argument('--decay', type=str, default='200',
                    help='learning rate decay type')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                    help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
                    help='gradient clipping threshold (0 = no clipping)')

# ckpt specifications
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
                    help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
                    help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=100,
                    help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
                    help='path of saved ckpt')

# Task
parser.add_argument('--task_id', type=int, default=0)

# ModelArts
parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
                    help='train on modelarts or not, default is False')
parser.add_argument('--data_url', type=str, default='', help='the directory path of saved file')

args, unparsed = parser.parse_known_args()

args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')

if args.epochs == 0:
    args.epochs = 1e8

for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False

@ -0,0 +1,97 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import random
import os
import numpy as np


def get_patch(*args, patch_size=96, scale=2, input_large=False):
    """get_patch"""
    ih, iw = args[0].shape[:2]
    tp = patch_size
    ip = tp // scale
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy
    ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
    return ret


def set_channel(*args, n_channels=3):
    """set_channel"""
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)
        c = img.shape[2]
        if n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)
        return img[:, :, :n_channels]
    return [_set_channel(a) for a in args]


def np2Tensor(*args, rgb_range=255):
    """ np2Tensor"""
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        input_data = np_transpose.astype(np.float32)
        output = input_data * (rgb_range / 255)
        return output
    return [_np2Tensor(a) for a in args]


def augment(*args, hflip=True, rot=True):
    """augment"""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        """augment"""
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img
    return [_augment(a) for a in args]


def search(root, target="JPEG"):
    """search"""
    item_list = []
    items = os.listdir(root)
    for item in items:
        path = os.path.join(root, item)
        if os.path.isdir(path):
            item_list.extend(search(path, target))
        elif path.split('/')[-1].startswith(target):
            item_list.append(path)
        elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
            item_list.append(path)
    return item_list

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""div2k"""
import os
from src.data.srdata import SRData


class DIV2K(SRData):
    """DIV2K"""

    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        self.dir_hr = None
        self.dir_lr = None
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]
        self.begin, self.end = list(map(int, data_range))
        super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')

@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""srdata"""
import os
import glob
import random
import pickle
import imageio
from src.data import common
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True


class SRData:
    """srdata"""

    def __init__(self, args, name='', train=True, benchmark=False):
        self.derain_lr_test = None
        self.derain_hr_test = None
        self.deblur_lr_test = None
        self.deblur_hr_test = None
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0
        self._set_filesystem(args.dir_data)
        self._set_img(args)
        if train:
            self._repeat(args)

    def _set_img(self, args):
        """set_img"""
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)
        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or self.benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
            for s in self.scale:
                if s == 1:
                    os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
                else:
                    os.makedirs(
                        os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

    def _repeat(self, args):
        """repeat"""
        n_patches = args.batch_size * args.test_every
        n_images = len(args.data_train) * len(self.images_hr)
        if n_images == 0:
            self.repeat = 0
        else:
            self.repeat = max(n_patches // n_images, 1)

    def _scan(self):
        """_scan"""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                if s != 1:
                    scale = s
                    names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
                                                     .format(s, filename, scale, self.ext[1])))
        for si, s in enumerate(self.scale):
            if s == 1:
                names_lr[si] = names_hr
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        """set_filesystem"""
        self.apath = os.path.join(dir_data, self.name[0])
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        """check_and_load"""
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    # pylint: disable=unused-variable
    def __getitem__(self, idx):
        """get item"""
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
        return pair_t[0], pair_t[1]

    def __len__(self):
        """length of hr"""
        if self.train:
            return len(self.images_hr) * self.repeat
        return len(self.images_hr)

    def _get_index(self, idx):
        """get_index"""
        if self.train:
            return idx % len(self.images_hr)
        return idx

    def _load_file_hr(self, idx):
        """load_file_hr"""
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
        return hr, filename

    def _load_file(self, idx):
        """load_file"""
        idx = self._get_index(idx)
        # print(idx,flush=True)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)
        return lr, hr, filename

    def get_patch_hr(self, hr):
        """get_patch_hr"""
        if self.train:
            hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
        return hr

    def get_patch_img_hr(self, img, patch_size=96, scale=2):
        """get_patch_img_hr"""
        ih, iw = img.shape[:2]
        tp = patch_size
        ip = tp // scale
        ix = random.randrange(0, iw - ip + 1)
        iy = random.randrange(0, ih - ip + 1)
        ret = img[iy:iy + ip, ix:ix + ip, :]
        return ret

    def get_patch(self, lr, hr):
        """get_patch"""
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size * scale,
                scale=scale)
            if not self.args.no_augment:
                lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]
        return lr, hr

    def set_scale(self, idx_scale):
        """set_scale"""
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""
import math
import numpy as np
import cv2


def quantize(img, rgb_range):
    """quantize image range to 0-255"""
    pixel_range = 255 / rgb_range
    img = np.multiply(img, pixel_range)
    img = np.clip(img, 0, 255)
    img = np.round(img) / pixel_range
    return img


def calc_psnr(sr, hr, scale, rgb_range):
    """calculate psnr"""
    hr = np.float32(hr)
    sr = np.float32(sr)
    diff = (sr - hr) / rgb_range
    gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
    diff = np.multiply(diff, gray_coeffs).sum(1)
    if hr.size == 1:
        return 0
    shave = scale
    valid = diff[..., shave:-shave, shave:-shave]
    mse = np.mean(pow(valid, 2))
    return -10 * math.log10(mse)
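

def rgb2ycbcr(img, y_only=True):
    """from rgb space to ycbcr space"""
    # astype returns a new array, so the result has to be assigned.
    img = img.astype(np.float32)
    if not y_only:
        # Only the luma (Y) channel is implemented here.
        raise NotImplementedError('Only y_only=True is supported.')
    rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    return rlt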


def calc_ssim(img1, img2, scale):
    """calculate ssim"""
    def ssim(img1, img2):
        """calculate ssim"""
        C1 = (0.01 * 255) ** 2
        C2 = (0.03 * 255) ** 2
        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        kernel = cv2.getGaussianKernel(11, 1.5)
        window = np.outer(kernel, kernel.transpose())
        mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
        mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
        sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
        sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()

    # Validate the shapes before any arithmetic is done on the inputs.
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    border = scale
    # Convert RGB to the luma (Y) channel.
    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]
    if img1_y.ndim == 2:
        return ssim(img1_y, img2_y)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                # Compare channel i of both images rather than the full image three times.
                ssims.append(ssim(img1[..., i], img2[..., i]))
            return np.array(ssims).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    raise ValueError('Wrong input image dimensions.')
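
`calc_psnr` above divides the error by `rgb_range`, so the peak signal is 1 and PSNR reduces to -10·log10(MSE). A small pure-Python sketch of that relationship on flat pixel lists. `psnr_normalized` is an illustrative name, and the sketch deliberately skips the luma weighting and border shaving done in `calc_psnr`.

```python
import math


def psnr_normalized(sr, hr, rgb_range):
    """PSNR in dB for two equally sized flat pixel lists."""
    diffs = [(s - h) / rgb_range for s, h in zip(sr, hr)]
    mse = sum(d * d for d in diffs) / len(diffs)
    if mse == 0:
        return float('inf')  # identical images
    # With the error scaled to [0, 1] the peak value is 1: 10 * log10(1 / mse).
    return -10 * math.log10(mse)


hr_pixels = [0.0, 51.0, 102.0, 153.0]
sr_pixels = [p + 25.5 for p in hr_pixels]  # constant error of 10% of the range
psnr_db = psnr_normalized(sr_pixels, hr_pixels, rgb_range=255)
```

A constant error of one tenth of the range gives an MSE of about 0.01, hence roughly 20 dB; halving the error adds about 6 dB.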


@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""rcan"""
import math
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common import Tensor, Parameter


def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
    """rcan"""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), has_bias=has_bias, pad_mode='pad')


class MeanShift(nn.Conv2d):
    """rcan"""
    def __init__(self,
                 rgb_range,
                 rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0),
                 sign=-1):
        """rcan"""
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        self.reshape = P.Reshape()
        self.eye = P.Eye()
        std = Tensor(rgb_std, mstype.float32)
        self.weight.set_data(
            self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1)))
        self.weight.requires_grad = False
        self.bias = Parameter(
            sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False)
        self.has_bias = True


def _pixelsf_(x, scale):
    """rcan"""
    n, c, ih, iw = x.shape
    oh = ih * scale
    ow = iw * scale
    oc = c // (scale ** 2)
    output = P.Transpose()(x, (0, 2, 1, 3))
    output = P.Reshape()(output, (n, ih, oc * scale, scale, iw))
    output = P.Transpose()(output, (0, 1, 2, 4, 3))
    output = P.Reshape()(output, (n, ih, oc, scale, ow))
    output = P.Transpose()(output, (0, 2, 1, 3, 4))
    output = P.Reshape()(output, (n, oc, oh, ow))
    return output


class SmallUpSampler(nn.Cell):
    """rcan"""
    def __init__(self, conv, upsize, n_feats, has_bias=True):
        """rcan"""
        super(SmallUpSampler, self).__init__()
        self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias)
        self.reshape = P.Reshape()
        self.upsize = upsize
        self.pixelsf = _pixelsf_

    def construct(self, x):
        """rcan"""
        x = self.conv(x)
        output = self.pixelsf(x, self.upsize)
        return output


class Upsampler(nn.Cell):
    """rcan"""
    def __init__(self, conv, scale, n_feats, has_bias=True):
        """rcan"""
        super(Upsampler, self).__init__()
        m = []
        if (scale & (scale - 1)) == 0:
            for _ in range(int(math.log(scale, 2))):
                m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
        elif scale == 3:
            m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
        self.net = nn.SequentialCell(m)

    def construct(self, x):
        """rcan"""
        return self.net(x)
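

class AdaptiveAvgPool2d(nn.Cell):
    """Global average pooling over the spatial axes of an NCHW tensor."""
    def __init__(self):
        """rcan"""
        super().__init__()
        self.ReduceMean = ops.ReduceMean(keep_dims=True)

    def construct(self, x):
        """Reduce (N, C, H, W) to (N, C, 1, 1)."""
        # Axis 0 would average over the batch; channel attention needs one
        # value per channel and sample, so reduce over height and width.
        return self.ReduceMean(x, (2, 3))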


class CALayer(nn.Cell):
    """rcan"""
    def __init__(self, channel, reduction=16):
        """rcan"""
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = AdaptiveAvgPool2d()
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.SequentialCell([
            nn.Conv2d(channel, channel // reduction, 1, padding=0, has_bias=True, pad_mode='pad'),
            nn.ReLU(),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, has_bias=True, pad_mode='pad'),
            nn.Sigmoid()
        ])

    def construct(self, x):
        """rcan"""
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


class RCAB(nn.Cell):
    """rcan"""
    def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True,
                 bn=False, act=nn.ReLU(), res_scale=1):
        """rcan"""
        super(RCAB, self).__init__()
        self.modules_body = []
        for i in range(2):
            self.modules_body.append(conv(n_feat, n_feat, kernel_size, has_bias=has_bias))
            if bn:
                self.modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                self.modules_body.append(act)
        self.modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.SequentialCell(*self.modules_body)
        self.res_scale = res_scale

    def construct(self, x):
        """rcan"""
        res = self.body(x)
        res += x
        return res


class ResidualGroup(nn.Cell):
    """rcan"""
    def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks):
        """rcan"""
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1)
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.SequentialCell(*modules_body)

    def construct(self, x):
        """rcan"""
        res = self.body(x)
        res += x
        return res


class RCAN(nn.Cell):
    """Residual channel attention network: head conv, residual groups, upsampling tail."""
    def __init__(self, args, conv=default_conv):
        """Build the network from the parsed arguments."""
        super(RCAN, self).__init__()
        n_resgroups = args.n_resgroups
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        reduction = args.reduction
        idx = args.task_id
        scale = args.scale[idx]
        # Compute type applied to every sub-module through to_float().
        self.dtype = mstype.float16
        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dtype)
        # define head module
        modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dtype)
        # define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dtype)
            for _ in range(n_resgroups)]
        modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dtype))
        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats).to_float(self.dtype),
            conv(n_feats, args.n_colors, kernel_size).to_float(self.dtype)]
        self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dtype)
        self.head = modules_head
        self.body = nn.SequentialCell(modules_body)
        self.tail = nn.SequentialCell(modules_tail)

    def construct(self, x):
        """Subtract the mean, run head, body (with long skip) and tail, then add the mean back."""
        x = self.sub_mean(x)
        x = self.head(x)
        res = self.body(x)
        res += x
        x = self.tail(res)
        x = self.add_mean(x)
        return x
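
The transpose/reshape chain in `_pixelsf_` rearranges a tensor of shape (N, C·r², H, W) into (N, C, H·r, W·r). Tracing the reshapes suggests it is equivalent to the index mapping out[n][c][h·r+i][w·r+j] = in[n][c·r²+i·r+j][h][w]. A pure-Python sketch of that mapping on nested lists follows; `pixel_shuffle` is an illustrative name, not a MindSpore API.

```python
def pixel_shuffle(x, r):
    """Rearrange x[n][c*r*r + i*r + j][h][w] into out[n][c][h*r + i][w*r + j]."""
    n, c_in, ih, iw = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    if c_in % (r * r) != 0:
        raise ValueError('channel count must be divisible by r * r')
    oc = c_in // (r * r)
    out = [[[[0] * (iw * r) for _ in range(ih * r)] for _ in range(oc)]
           for _ in range(n)]
    for b in range(n):
        for c in range(oc):
            for h in range(ih):
                for w in range(iw):
                    for i in range(r):
                        for j in range(r):
                            # Each group of r*r input channels fills one r x r output tile.
                            src = c * r * r + i * r + j
                            out[b][c][h * r + i][w * r + j] = x[b][src][h][w]
    return out


# One sample with 4 channels of size 1x2; channel k holds 10*k and 10*k + 1.
x = [[[[10 * k, 10 * k + 1]] for k in range(4)]]
y = pixel_shuffle(x, 2)
```

Here the four 1x2 input channels become a single 2x4 output channel, with each input position expanded into a 2x2 tile taken from channels 0 to 3.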


@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
import os
import time
from mindspore import context
from mindspore.context import ParallelMode
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.args import args
from src.data.div2k import DIV2K
from src.rcan_model import RCAN


def train():
    """train"""
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
    train_dataset.set_scale(args.task_id)
    if args.modelArts_mode:
        import moxing as mox
        local_data_url = '/cache/data'
        if device_num > 1:
            init()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                              device_num=device_num, gradients_mean=True)
        mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
        train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
        train_dataset.set_scale(args.task_id)
    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
                                           shard_id=rank_id, shuffle=True)
    train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
    net_m = RCAN(args)
    print("Init net weights successfully")
    if args.ckpt_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
        print("Load net weight successfully")
    step_size = train_de_dataset.get_dataset_size()
    lr = []
    for i in range(0, args.epochs):
        cur_lr = args.lr / (2 ** ((i + 1) // 200))
        lr.extend([cur_lr] * step_size)
    opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=1024.0)
    loss = nn.L1Loss()
    loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale,
                                                 scale_factor=2, scale_window=1000)
    model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
                                 keep_checkpoint_max=args.ckpt_save_max)
    ckpt_cb = ModelCheckpoint(prefix="rcan", directory=args.ckpt_save_path, config=config_ck)
    if device_id == 0:
        cb += [ckpt_cb]
    model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)


if __name__ == "__main__":
    time_start = time.time()
    train()
    time_end = time.time()
    print('train_time: %f' % (time_end - time_start))
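
The learning-rate list built in `train()` holds one value per training step and halves the rate every 200 epochs. A standalone sketch of the same schedule, so its boundaries can be checked without MindSpore. `step_lr_schedule` and the `decay_epochs` parameter are illustrative names, and the sample values are made up.

```python
def step_lr_schedule(base_lr, epochs, steps_per_epoch, decay_epochs=200):
    """Return one learning rate per training step, halved every `decay_epochs`."""
    lrs = []
    for epoch in range(epochs):
        # (epoch + 1) // decay_epochs counts how many halvings apply to this epoch.
        cur_lr = base_lr / (2 ** ((epoch + 1) // decay_epochs))
        lrs.extend([cur_lr] * steps_per_epoch)
    return lrs


schedule = step_lr_schedule(base_lr=1e-4, epochs=400, steps_per_epoch=3)
```

Because of the `+ 1`, the first halving already applies to zero-based epoch 199, one epoch earlier than a plain `epoch // 200` would give.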