add EDSR to modelzoo

This commit is contained in:
smullwy 2021-06-28 19:50:08 +08:00 committed by gengdongjie
parent caf38be2b9
commit ef783f4b06
15 changed files with 1462 additions and 0 deletions

View File

@ -0,0 +1,293 @@
目录
<!-- TOC -->
- [目录](#目录)
- [EDSR描述](#EDSR描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [训练](#训练)
- [分布式训练](#分布式训练)
- [评估过程](#评估过程)
- [评估](#评估)
- [模型描述](#模型描述)
- [性能](#性能)
- [训练性能](#训练性能)
- [DIV2K上的EDSR](#DIV2K上的EDSR)
- [评估性能](#评估性能)
- [Set5,Set14,B100,Urban100上的EDSR](#Set5,Set14,B100,Urban100上的EDSR)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页)
<!-- /TOC -->
# EDSR描述
EDSR是2017年提出的32层深度网络在2017年图像恢复和增强的新趋势研讨会上的超分挑战NTIRE2017 Super-Resolution Challenge中获得第一名。 EDSR相比于SRResNet减少了每个残差块中的batch normalization层,SRResNet相对于原本的ResNet则在每个残差块的出口减去了ReLU层.
[论文](https://arxiv.org/abs/1707.02921)Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** *2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**.
# 模型架构
EDSR先经过1次卷积层,再串联32个残差模块,再经过1次卷积层,最后上采样并卷积。
# 数据集
使用的数据集:[DIV2K](<http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf>)
- 数据集大小7.11G
- 训练集共800张图像采用了前800张进行训练
- 测试集共100张图像
- 数据格式png文件
- 注数据将在src/data/DIV2K.py中处理。
```shell
DIV2K
├── DIV2K_test_LR_bicubic
│   ├── X2
│   │   ├── 0901x2.png
│ │ ├─ ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│ │ ├─ ...
│   │   └── 1000x3.png
│   └── X4
│   ├── 0901x4.png
│ ├─ ...
│   └── 1000x4.png
├── DIV2K_test_LR_unknown
│   ├── X2
│   │   ├── 0901x2.png
│ │ ├─ ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│ │ ├─ ...
│   │   └── 1000x3.png
│   └── X4
│   ├── 0901x4.png
│ ├─ ...
│   └── 1000x4.png
├── DIV2K_train_HR
│   ├── 0001.png
│ ├─ ...
│   └── 0900.png
├── DIV2K_train_LR_bicubic
│   ├── X2
│   │   ├── 0001x2.png
│ │ ├─ ...
│   │   └── 0900x2.png
│   ├── X3
│   │   ├── 0001x3.png
│ │ ├─ ...
│   │   └── 0900x3.png
│   └── X4
│   ├── 0001x4.png
│ ├─ ...
│   └── 0900x4.png
└── DIV2K_train_LR_unknown
├── X2
│   ├── 0001x2.png
│ ├─ ...
│   └── 0900x2.png
├── X3
│   ├── 0001x3.png
│ ├─ ...
│   └── 0900x3.png
└── X4
├── 0001x4.png
├─ ...
└── 0900x4.png
```
# 环境要求
- 硬件Ascend
- 使用ascend处理器来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# 快速入门
通过官方网站安装MindSpore后您可以按照如下步骤进行训练和评估
```shell
#单卡训练
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
```
```shell
#分布式训练
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
```python
#评估
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
```
# 脚本说明
## 脚本及样例代码
```bash
├── model_zoo
├── EDSR
├── README_CN.md //自述文件
├── eval.py //评估脚本
├── export.py //导出脚本
├── script
│ ├── run_ascend_distribute.sh //Ascend分布式训练shell脚本
│ ├── run_ascend_standalone.sh //Ascend单卡训练shell脚本
│ └── run_eval.sh //eval验证shell脚本
├── src
│ ├── args.py //超参数
│ ├── common.py //公共网络模块
│ ├── data
│ │ ├── common.py //公共数据集
│ │ ├── div2k.py //div2k数据集
│ │ └── srdata.py //所有数据集
│ ├── metrics.py //PSNR和SSIM计算器
│ ├── model.py //EDSR网络
│ └── utils.py //训练脚本
└── train.py //训练脚本
```
## 脚本参数
主要参数如下:
```python
-h, --help show this help message and exit
--dir_data DIR_DATA dataset directory
--data_train DATA_TRAIN
train dataset name
--data_test DATA_TEST
test dataset name
--data_range DATA_RANGE
train/test data range
--ext EXT dataset file extension
--scale SCALE super resolution scale
--patch_size PATCH_SIZE
output patch size
--rgb_range RGB_RANGE
maximum value of RGB
--n_colors N_COLORS number of color channels to use
--no_augment do not use data augmentation
--model MODEL model name
--n_resblocks N_RESBLOCKS
number of residual blocks
--n_feats N_FEATS number of feature maps
--res_scale RES_SCALE
residual scaling
--test_every TEST_EVERY
do test per every N batches
--epochs EPOCHS number of epochs to train
--batch_size BATCH_SIZE
input batch size for training
--test_only set this option to test the model
--lr LR learning rate
--ckpt_save_path CKPT_SAVE_PATH
path to save ckpt
--ckpt_save_interval CKPT_SAVE_INTERVAL
save ckpt frequency, unit is epoch
--ckpt_save_max CKPT_SAVE_MAX
max number of saved ckpt
--ckpt_path CKPT_PATH
path of saved ckpt
--task_id TASK_ID
```
## 训练过程
### 训练
- Ascend处理器环境运行
```bash
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
```
如果数据集保存路径为G:\DIV2K`TRAIN_DATA_DIR`应传入G:\。
上述python命令将在后台运行您可以通过train.log文件查看结果。
### 分布式训练
- Ascend处理器环境运行
```bash
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
如果数据集保存路径为G:\DIV2K`TRAIN_DATA_DIR`应传入G:\。
## 评估过程
### 评估
在运行以下命令之前,请检查用于评估的检查点路径。
```bash
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
```
`DATASET_TYPE`可选 ["Set5", "Set14", "B100", "Urban100", "DIV2K"]
如果数据集保存路径为G:\DIV2K或者G:\Set5或者G:\Set14或者G:\B100或者G:\Urban100`TRAIN_DATA_DIR`应传入G:\。
您可以通过log.txt文件查看结果。
# 模型描述
## 性能
### 训练性能
| 参数 | Ascend |
| ------------- | ------------------------------------------------------------ |
| 资源 | Ascend 910 |
| 上传日期 | 2021-7-4 |
| MindSpore版本 | 1.2.0 |
| 数据集 | DIV2K |
| 训练参数 | epoch=1000, steps=1000, batch_size =16, lr=0.0001 |
| 优化器 | Adam |
| 损失函数 | L1 |
| 输出 | 超分辨率图片 |
| 损失 | 3.1 |
| 速度 | 8卡50.75毫秒/步 |
| 总时长 | 8卡12.865小时 |
| 微调检查点 | 466.13 MB (.ckpt文件) |
| 脚本 | [EDSR](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/EDSR) |
### 评估性能
| 参数 | Ascend |
| ------------- | ----------------------------------------------------------- |
| 资源 | Ascend 910 |
| 上传日期 | 2021-7-4 |
| MindSpore版本 | 1.2.0 |
| 数据集 | Set5,Set14,B100,Urban100 |
| batch_size | 1 |
| 输出 | 超分辨率图片 |
| PSNR | Set5:38.2136, Set14:34.0081, B100:32.3590, Urban100:33.0162 |
# 随机情况说明
在train.py中我们设置了“train_net”函数内的种子。
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr eval script"""
import os
import numpy as np
import mindspore.dataset as ds
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
import src.model as edsr
from src.data.srdata import SRData
from src.data.div2k import DIV2K
from src.metrics import calc_psnr, quantize, calc_ssim
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(max_call_depth=10000)
def eval_net():
"""eval"""
if args.epochs == 0:
args.epochs = 1e8
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
if args.data_test[0] == 'DIV2K':
train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
else:
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
net_m = edsr.EDSR(args)
if args.ckpt_path:
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net_m, param_dict)
net_m.set_train(False)
print('load mindspore net successfully.')
num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1))
ssims = np.zeros((num_imgs, 1))
for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR']
hr = imgs['HR']
lr = Tensor(lr, mstype.float32)
pred = net_m(lr)
pred_np = pred.asnumpy()
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
ssim = calc_ssim(pred_np, hr, args.scale[0])
print("current psnr: ", psnr)
print("current ssim: ", ssim)
psnrs[batch_idx, 0] = psnr
ssims[batch_idx, 0] = ssim
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
if __name__ == '__main__':
print("Start eval function!")
eval_net()

View File

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export net together with checkpoint into air/mindir/onnx models"""
import os
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
import src.model as edsr
parser = argparse.ArgumentParser(description='edsr export')
parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
parser.add_argument("--file_name", type=str, default="edsr", help="output file name.")
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
parser.add_argument('--scale', type=str, default='2', help='super resolution scale')
parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use')
parser.add_argument('--n_resblocks', type=int, default=32, help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=256, help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=0.1, help='residual scaling')
parser.add_argument('--task_id', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=1)
args1 = parser.parse_args()
args1.scale = [int(x) for x in args1.scale.split("+")]
for arg in vars(args1):
if vars(args1)[arg] == 'True':
vars(args1)[arg] = True
elif vars(args1)[arg] == 'False':
vars(args1)[arg] = False
def run_export(args):
"""run_export"""
device_id = int(os.getenv("DEVICE_ID", '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
net = edsr.EDSR(args)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
print('load mindspore net and checkpoint successfully.')
inputs = Tensor(np.zeros([args.batch_size, 3, 678, 1020], np.float32))
export(net, inputs, file_name=args.file_name, file_format=args.file_format)
print('export successfully!')
if __name__ == "__main__":
run_export(args1)

View File

@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH1 ]; then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py \
--batch_size 2 \
--lr 1e-4 \
--scale 2 \
--task_id 0 \
--dir_data $PATH2 \
--epochs 1000 \
--test_every 8000 \
--n_resblocks 32 \
--n_feats 256 \
--res_scale 0.1 \
--patch_size 48 > train.log 2>&1 &
cd ..
done

View File

@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]; then
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ -d "train" ]; then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
env >env.log
python train.py \
--batch_size 16 \
--lr 1e-4 \
--scale 2 \
--task_id 0 \
--dir_data $PATH1 \
--epochs 1000 \
--test_every 1000 \
--n_resblocks 32 \
--n_feats 256 \
--res_scale 0.1 \
--patch_size 48 > train.log 2>&1 &

View File

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
DATASET_TYPE=$3
if [ ! -d $PATH1 ]; then
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]; then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation ..."
python eval.py \
--dir_data=${PATH1} \
--test_only \
--ext img \
--ckpt_path=${PATH2} \
--task_id 0 \
--scale 2 \
--data_test=${DATASET_TYPE} \
--device_id 0 \
--n_resblocks 32 \
--n_feats 256 \
--res_scale 0.1 > log.txt 2>&1 &

View File

@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""args parser"""
import argparse
parser = argparse.ArgumentParser(description='EDSR')
# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
help='dataset directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-900',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
help='input patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='EDSR',
help='model name')
parser.add_argument('--n_resblocks', type=int, default=32,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=256,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=0.1,
help='residual scaling')
# Training specifications
parser.add_argument('--test_every', type=int, default=8000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=2,
help='input batch size for training')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--loss_scale', type=float, default=1024.0,
help='init loss scale')
# ckpt specifications
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=5,
help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
help='path of saved ckpt')
# alltask
parser.add_argument('--task_id', type=int, default=0)
args, unparsed = parser.parse_known_args()
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
args.epochs = 1e4
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False

View File

@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import math
import numpy as np
import mindspore
import mindspore.nn as nn
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
pad_mode='pad',
padding=(kernel_size//2), has_bias=bias)
class MeanShift(mindspore.nn.Conv2d):
"""MeanShift"""
def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, dtype=mindspore.float32):
std = mindspore.Tensor(rgb_std, dtype)
weight = mindspore.Tensor(np.eye(3), dtype).reshape(
3, 3, 1, 1) / std.reshape(3, 1, 1, 1)
bias = sign * rgb_range * mindspore.Tensor(rgb_mean, dtype) / std
super(MeanShift, self).__init__(3, 3, kernel_size=1,
has_bias=True, weight_init=weight, bias_init=bias)
for p in self.get_parameters():
p.requires_grad = False
class ResBlock(nn.Cell):
"""ResBlock"""
def __init__(
self, conv, n_feats, kernel_size,
bias=True, act=nn.ReLU(), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if i == 0:
m.append(act)
self.body = nn.SequentialCell(m)
self.res_scale = res_scale
self.mul = mindspore.ops.Mul()
def construct(self, x):
res = self.body(x)
res = self.mul(res, self.res_scale)
res += x
return res
class PixelShuffle(nn.Cell):
"""PixelShuffle"""
def __init__(self, upscale_factor):
super(PixelShuffle, self).__init__()
self.DepthToSpace = mindspore.ops.DepthToSpace(upscale_factor)
def construct(self, x):
return self.DepthToSpace(x)
def Upsampler(conv, scale, n_feats, bias=True):
"""Upsampler"""
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(PixelShuffle(2))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(PixelShuffle(3))
else:
raise NotImplementedError
return m

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import random
import numpy as np
def get_patch(*args, patch_size=96, scale=2, input_large=False):
"""common"""
ih, iw = args[0].shape[:2]
tp = patch_size
ip = tp // scale
ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)
if not input_large:
tx, ty = scale * ix, scale * iy
else:
tx, ty = ix, iy
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
return ret
def set_channel(*args, n_channels=3):
"""common"""
def _set_channel(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
c = img.shape[2]
if n_channels == 3 and c == 1:
img = np.concatenate([img] * n_channels, 2)
return img[:, :, :n_channels]
return [_set_channel(a) for a in args]
def np2Tensor(*args, rgb_range=255):
def _np2Tensor(img):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
input_data = np_transpose.astype(np.float32)
output = input_data * (rgb_range / 255)
return output
return [_np2Tensor(a) for a in args]
def augment(*args, hflip=True, rot=True):
"""common"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
"""common"""
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(a) for a in args]
def search(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
return item_list

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""div2k"""
import os
from src.data.srdata import SRData
class DIV2K(SRData):
"""DIV2K"""
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
data_range = [r.split('-') for r in args.data_range.split('/')]
if train:
data_range = data_range[0]
else:
if args.test_only and len(data_range) == 1:
data_range = data_range[0]
else:
data_range = data_range[1]
self.begin, self.end = list(map(int, data_range))
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
self.dir_hr = None
self.dir_lr = None
def _scan(self):
names_hr, names_lr = super(DIV2K, self)._scan()
names_hr = names_hr[self.begin - 1:self.end]
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
return names_hr, names_lr
def _set_filesystem(self, dir_data):
super(DIV2K, self)._set_filesystem(dir_data)
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')

View File

@ -0,0 +1,212 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""""srdata"""
import os
import glob
import random
import pickle
import imageio
from src.data import common
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class SRData:
"""srdata"""
def __init__(self, args, name='', train=True, benchmark=False):
self.args = args
self.name = name
self.train = train
self.split = 'train' if train else 'test'
self.do_eval = True
self.benchmark = benchmark
self.input_large = (args.model == 'VDSR')
self.scale = args.scale
self.idx_scale = 0
self._set_filesystem(args.dir_data)
self._set_img(args)
if train:
self._repeat(args)
def _set_img(self, args):
"""srdata"""
if args.ext.find('img') < 0:
path_bin = os.path.join(self.apath, 'bin')
os.makedirs(path_bin, exist_ok=True)
list_hr, list_lr = self._scan()
if args.ext.find('img') >= 0 or self.benchmark:
self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0:
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
for s in self.scale:
if s == 1:
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
else:
os.makedirs(
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
for h in list_hr:
b = h.replace(self.apath, path_bin)
b = b.replace(self.ext[0], '.pt')
self.images_hr.append(b)
self._check_and_load(args.ext, h, b, verbose=True)
for i, ll in enumerate(list_lr):
for l in ll:
b = l.replace(self.apath, path_bin)
b = b.replace(self.ext[1], '.pt')
self.images_lr[i].append(b)
self._check_and_load(args.ext, l, b, verbose=True)
def _repeat(self, args):
"""srdata"""
n_patches = args.batch_size * args.test_every
n_images = len(args.data_train) * len(self.images_hr)
if n_images == 0:
self.repeat = 0
else:
self.repeat = max(n_patches // n_images, 1)
def _scan(self):
"""srdata"""
names_hr = sorted(
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
names_lr = [[] for _ in self.scale]
for f in names_hr:
filename, _ = os.path.splitext(os.path.basename(f))
for si, s in enumerate(self.scale):
if s != 1:
scale = s
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
.format(s, filename, scale, self.ext[1])))
for si, s in enumerate(self.scale):
if s == 1:
names_lr[si] = names_hr
return names_hr, names_lr
def _set_filesystem(self, dir_data):
self.apath = os.path.join(dir_data, self.name[0])
self.dir_hr = os.path.join(self.apath, 'HR')
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
self.ext = ('.png', '.png')
def _check_and_load(self, ext, img, f, verbose=True):
if not os.path.isfile(f) or ext.find('reset') >= 0:
if verbose:
print('Making a binary: {}'.format(f))
with open(f, 'wb') as _f:
pickle.dump(imageio.imread(img), _f)
def __getitem__(self, idx):
lr, hr, _ = self._load_file(idx)
pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
def __len__(self):
if self.train:
return len(self.images_hr) * self.repeat
return len(self.images_hr)
def _get_index(self, idx):
if self.train:
return idx % len(self.images_hr)
return idx
def _load_file_deblur(self, idx, train=True):
"""srdata"""
idx = self._get_index(idx)
if train:
f_hr = self.images_hr[idx]
f_lr = self.images_lr[idx]
else:
f_hr = self.deblur_hr_test[idx]
f_lr = self.deblur_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
filename = f_hr[-27:-17] + filename
hr = imageio.imread(f_hr)
lr = imageio.imread(f_lr)
return lr, hr, filename
def _load_file_hr(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_hr = self.images_hr[idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
return hr, filename
def _load_rain_test(self, idx):
f_hr = self.derain_hr_test[idx]
f_lr = self.derain_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_lr))
norain = imageio.imread(f_hr)
rain = imageio.imread(f_lr)
return norain, rain, filename
def _load_file(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_hr = self.images_hr[idx]
f_lr = self.images_lr[self.idx_scale][idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
lr = imageio.imread(f_lr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
with open(f_lr, 'rb') as _f:
lr = pickle.load(_f)
return lr, hr, filename
def get_patch_hr(self, hr):
"""srdata"""
if self.train:
hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
return hr
def get_patch_img_hr(self, img, patch_size=96, scale=2):
"""srdata"""
ih, iw = img.shape[:2]
tp = patch_size
ip = tp // scale
ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)
ret = img[iy:iy + ip, ix:ix + ip, :]
return ret
def get_patch(self, lr, hr):
"""srdata"""
scale = self.scale[self.idx_scale]
if self.train:
lr, hr = common.get_patch(
lr, hr,
patch_size=self.args.patch_size * scale,
scale=scale)
if not self.args.no_augment:
lr, hr = common.augment(lr, hr)
else:
ih, iw = lr.shape[:2]
hr = hr[0:ih * scale, 0:iw * scale]
return lr, hr
def set_scale(self, idx_scale):
if not self.input_large:
self.idx_scale = idx_scale
else:
self.idx_scale = random.randint(0, len(self.scale) - 1)

View File

@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""
import math
import numpy as np
import cv2
def quantize(img, rgb_range):
"""quantize image range to 0-255"""
pixel_range = 255 / rgb_range
img = np.multiply(img, pixel_range)
img = np.clip(img, 0, 255)
img = np.round(img) / pixel_range
return img
def calc_psnr(sr, hr, scale, rgb_range):
"""calculate psnr"""
hr = np.float32(hr)
sr = np.float32(sr)
diff = (sr - hr) / rgb_range
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
diff = np.multiply(diff, gray_coeffs).sum(1)
if hr.size == 1:
return 0
if scale != 1:
shave = scale
else:
shave = scale + 6
if scale == 1:
valid = diff
else:
valid = diff[..., shave:-shave, shave:-shave]
mse = np.mean(pow(valid, 2))
return -10 * math.log10(mse)
def rgb2ycbcr(img, y_only=True):
"""from rgb space to ycbcr space"""
img.astype(np.float32)
if y_only:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
return rlt
def calc_ssim(img1, img2, scale):
"""calculate ssim value"""
def ssim(img1, img2):
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
border = 0
if scale != 1:
border = scale
else:
border = scale + 6
img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1_y = img1_y[border:h - border, border:w - border]
img2_y = img2_y[border:h - border, border:w - border]
if img1_y.ndim == 2:
return ssim(img1_y, img2_y)
if img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for _ in range(3):
ssims.append(ssim(img1, img2))
return np.array(ssims).mean()
if img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')

View File

@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr_model"""
import mindspore.nn as nn
from src import common
class EDSR(nn.Cell):
"""EDSR"""
def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU()
self.sub_mean = common.MeanShift(args.rgb_range)
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = []
m_tail += common.Upsampler(conv, scale, n_feats)
m_tail.append(conv(n_feats, args.n_colors, kernel_size))
self.head = nn.SequentialCell(m_head)
self.body = nn.SequentialCell(m_body)
self.tail = nn.SequentialCell(m_tail)
def construct(self, x):
"""construct"""
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x

View File

@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr train wrapper"""
import os
import time
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint
class Trainer():
"""Trainer"""
def __init__(self, args, loader, my_model):
self.args = args
self.scale = args.scale
self.trainloader = loader
self.model = my_model
self.model.set_train()
self.criterion = nn.L1Loss()
self.loss_history = []
self.begin_time = time.time()
self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
self.loss_net = nn.WithLossCell(self.model, self.criterion)
self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer)
def train(self, epoch):
"""Trainer"""
losses = 0
batch_idx = 0
for batch_idx, imgs in enumerate(self.trainloader):
lr = imgs["LR"]
hr = imgs["HR"]
lr = Tensor(lr, mstype.float32)
hr = Tensor(hr, mstype.float32)
t1 = time.time()
loss = self.net(lr, hr)
t2 = time.time()
losses += loss.asnumpy()
print('Epoch: %g, Step: %g , loss: %f, time: %f s ' % \
(epoch, batch_idx, loss.asnumpy(), t2 - t1), end='\n', flush=True)
print("the epoch loss is", losses / (batch_idx + 1), flush=True)
self.loss_history.append(losses / (batch_idx + 1))
print(self.loss_history)
t = time.time() - self.begin_time
t = int(t)
print(", running time: %gh%g'%g''"%(t//3600, (t-t//3600*3600)//60, t%60), flush=True)
os.makedirs(self.args.save, exist_ok=True)
if self.args.rank == 0 and (epoch+1)%10 == 0:
save_checkpoint(self.net, self.args.save + "model_" + str(self.epoch) + '.ckpt')
def update_learning_rate(self, epoch):
"""Update learning rates for all the networks; called at the end of every epoch.
:param epoch: current epoch
:type epoch: int
:param lr: learning rate of cyclegan
:type lr: float
:param niter: number of epochs with the initial learning rate
:type niter: int
:param niter_decay: number of epochs to linearly decay learning rate to zero
:type niter_decay: int
"""
self.epoch = epoch
print("*********** epoch: {} **********".format(epoch))
lr = self.args.lr / (2 ** ((epoch+1)//200))
self.adjust_lr('model', self.optimizer, lr)
print("*********************************")
def adjust_lr(self, name, optimizer, lr):
"""Adjust learning rate for the corresponding model.
:param name: name of model
:type name: str
:param optimizer: the optimizer of the corresponding model
:type optimizer: torch.optim
:param lr: learning rate to be adjusted
:type lr: float
"""
lr_param = optimizer.get_lr()
lr_param.assign_value(Tensor(lr, mstype.float32))
print('==> ' + name + ' learning rate: ', lr_param.asnumpy())

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr train script"""
import os
from mindspore import context
from mindspore import dataset as ds
import mindspore.nn as nn
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
from mindspore.train.model import Model
from src.args import args
from src.data.div2k import DIV2K
from src.model import EDSR
def train_net():
"""train edsr"""
set_seed(1)
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = int(os.getenv('RANK_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))
# if distribute:
if device_num > 1:
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=device_num, gradients_mean=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
shard_id=rank_id, shuffle=True)
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
net_m = EDSR(args)
print("Init net successfully")
if args.ckpt_path:
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net_m, param_dict)
print("Load net weight successfully")
step_size = train_de_dataset.get_dataset_size()
lr = []
for i in range(0, args.epochs):
cur_lr = args.lr / (2 ** ((i + 1)//200))
lr.extend([cur_lr] * step_size)
opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
loss = nn.L1Loss()
model = Model(net_m, loss_fn=loss, optimizer=opt)
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
keep_checkpoint_max=args.ckpt_save_max)
ckpt_cb = ModelCheckpoint(prefix="edsr", directory=args.ckpt_save_path, config=config_ck)
if device_id == 0:
cb += [ckpt_cb]
model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)
if __name__ == "__main__":
train_net()