!18924 add wdsr to model zoo

Merge pull request !18924 from Humanyue/master
i-robot 2021-07-13 11:39:53 +00:00 committed by Gitee
commit 303c354c49
14 changed files with 1430 additions and 0 deletions


@ -0,0 +1,297 @@
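# Contents
<!-- TOC -->
- [Contents](#contents)
- [WDSR Description](#wdsr-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Export](#model-export)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->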
# WDSR Description
WDSR was proposed in 2018 to improve the accuracy of deep super-resolution networks. It won first place in all three realistic tracks of the NTIRE 2018 single image super-resolution challenge.

[Paper 1](https://arxiv.org/abs/1707.02921): Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution"**, *2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution*, in conjunction with **CVPR 2017**.

[Paper 2](https://arxiv.org/abs/1808.08718): Jiahui Yu, Yuchen Fan, Jianchao Yang, Ning Xu, Zhaowen Wang, Xinchao Wang, Thomas Huang, **"Wide Activation for Efficient and Accurate Image Super-Resolution"**, arXiv preprint arXiv:1808.08718.
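# Model Architecture
The WDSR network is built from a few basic modules, including convolution layers and pooling layers. It achieves more accurate single-image super-resolution through wider activation and linear low-rank convolution, combined with weight normalization. The basic modules mainly rely on two operations: **1 × 1 convolution** and **3 × 3 convolution**.

The input passes through one convolution layer, then 32 residual blocks connected in series, then one more convolution layer, and finally an upsampling step followed by a convolution.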
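
The stage order described above can be written down as a plain-Python layer plan. This is only a sketch of the data flow: the function names `wdsr_layer_plan` and `output_shape` and the stage labels are hypothetical and are not taken from `src/model.py`.

```python
def wdsr_layer_plan(n_resblocks=32):
    """Return the ordered stage names of the pipeline described above."""
    plan = ["head_conv"]
    plan += ["res_block_%d" % i for i in range(n_resblocks)]
    plan += ["body_conv", "upsample", "tail_conv"]
    return plan


def output_shape(lr_shape, scale):
    """Spatial bookkeeping only: (C, H, W) -> (C, H * scale, W * scale)."""
    channels, height, width = lr_shape
    return (channels, height * scale, width * scale)


plan = wdsr_layer_plan()
print(len(plan), plan[0], plan[-1])
print(output_shape((3, 678, 1020), 2))
```

The `(3, 678, 1020)` input is the low-resolution shape used by the export script; at scale 2 the output is `(3, 1356, 2040)`.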
# Dataset
Dataset used: [DIV2K](<http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf>)
- Dataset size: 7.11 GB
    - Training set: 900 images
    - Test set: 100 images
- Data format: PNG files
    - Note: the data is processed in src/data/div2k.py.
```text
DIV2K
├── DIV2K_test_LR_bicubic
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├── ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├── ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├── ...
│       └── 1000x4.png
├── DIV2K_test_LR_unknown
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├── ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├── ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├── ...
│       └── 1000x4.png
├── DIV2K_train_HR
│   ├── 0001.png
│   ├── ...
│   └── 0900.png
├── DIV2K_train_LR_bicubic
│   ├── X2
│   │   ├── 0001x2.png
│   │   ├── ...
│   │   └── 0900x2.png
│   ├── X3
│   │   ├── 0001x3.png
│   │   ├── ...
│   │   └── 0900x3.png
│   └── X4
│       ├── 0001x4.png
│       ├── ...
│       └── 0900x4.png
└── DIV2K_train_LR_unknown
    ├── X2
    │   ├── 0001x2.png
    │   ├── ...
    │   └── 0900x2.png
    ├── X3
    │   ├── 0001x3.png
    │   ├── ...
    │   └── 0900x3.png
    └── X4
        ├── 0001x4.png
        ├── ...
        └── 0900x4.png
```
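
The naming convention in the tree above (an HR file `0001.png` pairs with `X2/0001x2.png` at scale 2) can be sketched as a small helper. The function name `lr_path_for` and its default arguments are hypothetical; the format string follows the one used by `_scan` in `src/data/srdata.py`.

```python
import posixpath


def lr_path_for(hr_path, scale, lr_dir="DIV2K_train_LR_bicubic", ext=".png"):
    """Map an HR image path to its low-resolution counterpart."""
    stem, _ = posixpath.splitext(posixpath.basename(hr_path))
    return posixpath.join(lr_dir, "X%d" % scale, "%sx%d%s" % (stem, scale, ext))


print(lr_path_for("DIV2K_train_HR/0001.png", 2))
# DIV2K_train_LR_bicubic/X2/0001x2.png
```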
# Environment Requirements
- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# Quick Start
After installing MindSpore from the official website, you can train and evaluate as follows:
```shell
# Standalone (single-device) training
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
```
```shell
# Distributed training
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
```shell
# Evaluation
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
```
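# Script Description
## Scripts and Sample Code
```bash
WDSR
├── README_CN.md                       //README (Chinese)
├── eval.py                            //evaluation script
├── export.py                          //model export script
├── script
│   ├── run_ascend_distribute.sh       //shell script for distributed training on Ascend
│   ├── run_ascend_standalone.sh       //shell script for standalone training on Ascend
│   └── run_eval.sh                    //shell script for evaluation
├── src
│   ├── args.py                        //hyperparameters
│   ├── common.py                      //common network modules
│   ├── data
│   │   ├── common.py                  //common dataset utilities
│   │   ├── div2k.py                   //DIV2K dataset
│   │   └── srdata.py                  //base class for all datasets
│   ├── metrics.py                     //PSNR and SSIM calculators
│   ├── model.py                       //WDSR network
│   └── utils.py                       //training utilities
└── train.py                           //training script
```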
## Script Parameters
The main parameters are as follows:
```python
-h, --help show this help message and exit
--dir_data DIR_DATA dataset directory
--data_train DATA_TRAIN
train dataset name
--data_test DATA_TEST
test dataset name
--data_range DATA_RANGE
train/test data range
--ext EXT dataset file extension
--scale SCALE super resolution scale
--patch_size PATCH_SIZE
output patch size
--rgb_range RGB_RANGE
maximum value of RGB
--n_colors N_COLORS number of color channels to use
--no_augment do not use data augmentation
--model MODEL model name
--n_resblocks N_RESBLOCKS
number of residual blocks
--n_feats N_FEATS number of feature maps
--res_scale RES_SCALE
residual scaling
--test_every TEST_EVERY
do test per every N batches
--epochs EPOCHS number of epochs to train
--batch_size BATCH_SIZE
input batch size for training
--test_only set this option to test the model
--lr LR learning rate
--ckpt_save_path CKPT_SAVE_PATH
path to save ckpt
--ckpt_save_interval CKPT_SAVE_INTERVAL
save ckpt frequency, unit is epoch
--ckpt_save_max CKPT_SAVE_MAX
max number of saved ckpt
--ckpt_path CKPT_PATH
path of saved ckpt
--task_id TASK_ID
```
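
Two of these options take compound string values: `--scale` accepts several scales joined by `+`, and `--data_range` has the form `train_begin-train_end/test_begin-test_end`. The sketch below shows how such strings can be split. The helper names `parse_scale` and `parse_data_range` are hypothetical, and the range selection is a simplified version of the logic in `src/data/div2k.py`.

```python
def parse_scale(value):
    """Split a '+'-separated scale string, e.g. '2+3+4' -> [2, 3, 4]."""
    return [int(x) for x in value.split("+")]


def parse_data_range(value, train=True):
    """Pick the train or test part of a range such as '1-800/801-900'."""
    ranges = [r.split("-") for r in value.split("/")]
    # Simplification: fall back to the only range when just one is given.
    chosen = ranges[0] if train or len(ranges) == 1 else ranges[1]
    begin, end = map(int, chosen)
    return begin, end


print(parse_scale("2+3+4"))
print(parse_data_range("1-800/801-900", train=True))
print(parse_data_range("1-800/801-900", train=False))
```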
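## Training Process
### Training
- Running in an Ascend processor environment
```bash
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
```
The script above launches the Python training job in the background; you can check the results in the train.log file.
### Distributed Training
- Running in an Ascend processor environment
```bash
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
TRAIN_DATA_DIR = "~DATA/".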
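
Both training scripts pass `--patch_size 48 --scale 2`, so each training sample pairs a 48×48 low-resolution crop with the matching 96×96 high-resolution crop. The following is a minimal sketch of that paired cropping, modelled on `get_patch` in `src/data/common.py`; the function name `paired_crop_boxes` and the box layout `(top, bottom, left, right)` are choices made here for illustration.

```python
import random


def paired_crop_boxes(lr_height, lr_width, patch_size, scale, rng=random):
    """Pick a random LR crop and the HR crop that covers the same region."""
    ip = patch_size            # side of the low-resolution patch
    tp = patch_size * scale    # side of the high-resolution patch
    ix = rng.randrange(0, lr_width - ip + 1)
    iy = rng.randrange(0, lr_height - ip + 1)
    lr_box = (iy, iy + ip, ix, ix + ip)
    hr_box = (scale * iy, scale * iy + tp, scale * ix, scale * ix + tp)
    return lr_box, hr_box


lr_box, hr_box = paired_crop_boxes(678, 1020, patch_size=48, scale=2)
print(lr_box, hr_box)
```

Because the HR offsets are the LR offsets multiplied by the scale, the two crops always show the same image content.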
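## Evaluation Process
### Evaluation
Before running the command below, check the path of the checkpoint used for evaluation.
```bash
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] DIV2K
```
TEST_DATA_DIR = "~DATA/".
You can check the results in the eval.log file.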
# Model Export
```bash
python export.py --ckpt_path [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
FILE_FORMAT must be one of ['MINDIR', 'AIR', 'ONNX']; the default is 'MINDIR'.
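
To illustrate how these flags are validated, the sketch below rebuilds the same argument set with `argparse`, mirroring the definitions in `export.py`. The function name `build_export_parser` and the checkpoint path `wdsr.ckpt` are placeholders used only for this example.

```python
import argparse


def build_export_parser():
    """Recreate the export flags for illustration (mirrors export.py)."""
    parser = argparse.ArgumentParser(description="wdsr export (sketch)")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--file_name", type=str, default="wdsr")
    parser.add_argument("--file_format", type=str, default="MINDIR",
                        choices=["MINDIR", "AIR", "ONNX"])
    return parser


# "wdsr.ckpt" is a placeholder path used only for this example.
opts = build_export_parser().parse_args(
    ["--ckpt_path", "wdsr.ckpt", "--file_format", "AIR"])
print(opts.file_name, opts.file_format)
```

A value outside the three listed formats makes `argparse` exit with an error before any model is loaded.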
# Model Description
## Performance
### Training Performance
| Parameter                  | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Resource                   | Ascend 910                                                   |
| Upload date                | 2021-7-4                                                     |
| MindSpore version          | 1.2.0                                                        |
| Dataset                    | DIV2K                                                        |
| Training parameters        | epoch=1000, steps=100, batch_size=16, lr=0.0001              |
| Optimizer                  | Adam                                                         |
| Loss function              | L1                                                           |
| Output                     | Super-resolution image                                       |
| Loss                       | 3.5                                                          |
| Speed                      | About 130 ms/step (8 devices)                                |
| Total time                 | 0.5 hours (8 devices)                                        |
| Checkpoint for fine-tuning | 35 MB (.ckpt file)                                           |
| Script                     | [WDSR](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/wdsr) |
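
The `steps=100` entry can be related to the training flags. `SRData._repeat` in `src/data/srdata.py` revisits each image several times per epoch so that one epoch yields roughly `batch_size * test_every` samples. The sketch below reproduces that arithmetic under two assumptions made here: a single device, and the default `--data_range` of `1-800/801-900` (800 training images). The function names are hypothetical.

```python
def repeat_factor(n_images, batch_size, test_every, n_train_sets=1):
    """Mirror of SRData._repeat: how often each image is revisited per epoch."""
    n_patches = batch_size * test_every
    total_images = n_train_sets * n_images
    if total_images == 0:
        return 0
    return max(n_patches // total_images, 1)


def steps_per_epoch(n_images, batch_size, test_every):
    """Full batches per epoch on one device (assumes incomplete batches are dropped)."""
    samples = n_images * repeat_factor(n_images, batch_size, test_every)
    return samples // batch_size


print(steps_per_epoch(n_images=800, batch_size=16, test_every=100))  # 100
```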
### Evaluation Performance
| Parameter         | Ascend                  |
| ----------------- | ----------------------- |
| Resource          | Ascend 910              |
| Upload date       | 2021-7-4                |
| MindSpore version | 1.2.0                   |
| Dataset           | DIV2K                   |
| batch_size        | 1                       |
| Output            | Super-resolution image  |
| PSNR              | DIV2K 34.7780           |
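
For reference, PSNR is computed from the mean squared error of pixel differences normalized by `rgb_range`. The sketch below is a simplified pure-Python version on flat pixel lists. Unlike `calc_psnr` in `src/metrics.py`, it does not weight the channels into luminance or shave image borders, and returning infinity for identical inputs is a choice made here.

```python
import math


def psnr(sr, hr, rgb_range=255.0):
    """Peak signal-to-noise ratio in dB for two equally sized pixel lists."""
    if len(sr) != len(hr) or not sr:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum(((s - h) / rgb_range) ** 2 for s, h in zip(sr, hr)) / len(sr)
    if mse == 0:
        return float("inf")
    return -10.0 * math.log10(mse)


# Every pixel is off by 10% of the range, so the MSE is 0.01 (about 20 dB).
print(round(psnr([25.5, 25.5, 25.5], [0.0, 0.0, 0.0]), 4))
```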
# Description of Random Situation
In train.py, the random seed is set inside the `train_net` function.
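
The exact seeding calls in `train.py` are not shown here. As a generic illustration of why a fixed seed matters for the random crops and flips in the data pipeline, the sketch below draws crop offsets from a seeded generator; the function name `sample_crop_offsets` is hypothetical.

```python
import random


def sample_crop_offsets(seed, count=3, limit=100):
    """Draw `count` pseudo-random offsets in [0, limit) from a seeded generator."""
    rng = random.Random(seed)
    return [rng.randrange(0, limit) for _ in range(count)]


# The same seed reproduces the same sequence of offsets.
print(sample_crop_offsets(1) == sample_crop_offsets(1))
```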
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wdsr eval script"""
import os
import numpy as np
import mindspore.dataset as ds
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
import src.model as wdsr
from src.data.srdata import SRData
from src.data.div2k import DIV2K
from src.metrics import calc_psnr, quantize, calc_ssim
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(max_call_depth=10000)


def eval_net():
    """eval"""
    if args.epochs == 0:
        args.epochs = 1e8
    # Convert string flags such as 'True'/'False' into real booleans.
    for arg in vars(args):
        if vars(args)[arg] == 'True':
            vars(args)[arg] = True
        elif vars(args)[arg] == 'False':
            vars(args)[arg] = False
    # Build the evaluation dataset (the variables keep their original train_* names).
    if args.data_test[0] == 'DIV2K':
        train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
    else:
        train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
    train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
    train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
    net_m = wdsr.WDSR()
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(net_m, param_dict)
    net_m.set_train(False)
    print('load mindspore net successfully.')
    num_imgs = train_de_dataset.get_dataset_size()
    psnrs = np.zeros((num_imgs, 1))
    ssims = np.zeros((num_imgs, 1))
    for batch_idx, imgs in enumerate(train_loader):
        lr = imgs['LR']
        hr = imgs['HR']
        lr = Tensor(lr, mstype.float32)
        pred = net_m(lr)
        pred_np = pred.asnumpy()
        pred_np = quantize(pred_np, 255)
        psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
        # Drop the batch dimension and convert CHW to HWC for the SSIM computation.
        pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
        hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
        ssim = calc_ssim(pred_np, hr, args.scale[0])
        print("current psnr: ", psnr)
        print("current ssim: ", ssim)
        psnrs[batch_idx, 0] = psnr
        ssims[batch_idx, 0] = ssim
    print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
    print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))


if __name__ == '__main__':
    print("Start eval function!")
    eval_net()


@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export net together with checkpoint into air/mindir/onnx models"""
import os
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
import src.model as wdsr
parser = argparse.ArgumentParser(description='wdsr export')
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
parser.add_argument("--file_name", type=str, default="wdsr", help="output file name.")
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
args = parser.parse_args()


def run_export(args1):
    """run_export"""
    device_id = int(os.getenv("DEVICE_ID", '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    net = wdsr.WDSR()
    param_dict = load_checkpoint(args1.ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    print('load mindspore net and checkpoint successfully.')
    # Dummy input that fixes the exported input shape to (batch, 3, 678, 1020).
    inputs = Tensor(np.zeros([args1.batch_size, 3, 678, 1020], np.float32))
    # Use the function argument rather than the module-level args.
    export(net, inputs, file_name=args1.file_name, file_format=args1.file_format)
    print('export successfully!')


if __name__ == "__main__":
    run_export(args)


@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
    echo "Usage: sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
    exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH1 ]; then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
nohup python train.py \
--batch_size 16 \
--lr 1e-4 \
--scale 2 \
--task_id 0 \
--dir_data $PATH2 \
--epochs 1000 \
--test_every 100 \
--patch_size 48 > train.log 2>&1 &
cd ..
done


@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
    echo "Usage: sh run_ascend_standalone.sh [TRAIN_DATA_DIR]"
    exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]; then
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ -d "train" ]; then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
env >env.log
nohup python train.py \
--batch_size 16 \
--lr 1e-4 \
--scale 2 \
--task_id 0 \
--dir_data $PATH1 \
--epochs 1000 \
--test_every 100 \
--patch_size 48 > train.log 2>&1 &


@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
DATASET_TYPE=$3
if [ ! -d $PATH1 ]; then
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]; then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation ..."
python eval.py \
--dir_data=${PATH1} \
--test_only \
--ext "img" \
--data_test=${DATASET_TYPE} \
--ckpt_path=${PATH2} \
--task_id 0 \
--scale 2 > eval.log 2>&1 &


@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""args parser"""
import argparse
parser = argparse.ArgumentParser(description='WDSR')
# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
help='dataset directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-900',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='WDSR',
help='model name')
parser.add_argument('--n_resblocks', type=int, default=16,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
help='residual scaling')
# Training specifications
parser.add_argument('--test_every', type=int, default=1000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=300,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
help='input batch size for training')
#parser.add_argument('--self_ensemble', action='store_true',
# help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--init_loss_scale', type=float, default=65536.,
help='scaling factor')
parser.add_argument('--decay', type=str, default='200',
help='learning rate decay type')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
# ckpt specifications
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=5,
help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
help='path of saved ckpt')
# alltask
parser.add_argument('--task_id', type=int, default=0)
# rgb_mean
parser.add_argument('--r_mean', type=float, default=0.4488,
help='r_mean')
parser.add_argument('--g_mean', type=float, default=0.4371,
help='g_mean')
parser.add_argument('--b_mean', type=float, default=0.4040,
help='b_mean')
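args, unparsed = parser.parse_known_args()
# '--scale 2+3+4' becomes [2, 3, 4]; dataset names are '+'-separated as well.
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
    # Keep epochs an integer, matching the declared argument type.
    args.epochs = int(1e4)
# Convert string flags such as 'True'/'False' into real booleans.
for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False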


@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import os
import random
import numpy as np


def get_patch(*args, patch_size=96, scale=2, input_large=False):
    """Crop a random LR patch from args[0] and the aligned patch from every other image."""
    ih, iw = args[0].shape[:2]
    tp = patch_size      # target (HR) patch side
    ip = tp // scale     # input (LR) patch side
    # print("tp=%g,scale=%g,"%(tp,scale),end='',flush=True)
    # print("ih=%g,iw=%g,ip=%g"%(ih,iw,ip),flush=True)
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    # ix = iy = 0
    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy
    ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
    return ret


def set_channel(*args, n_channels=3):
    """Make every image HWC with n_channels channels."""
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)
        c = img.shape[2]
        if n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)
        return img[:, :, :n_channels]
    return [_set_channel(a) for a in args]


def np2Tensor(*args, rgb_range=255):
    """Convert HWC images to float32 CHW arrays scaled to rgb_range."""
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        input_data = np_transpose.astype(np.float32)
        output = input_data * (rgb_range / 255)
        return output
    return [_np2Tensor(a) for a in args]


def augment(*args, hflip=True, rot=True):
    """Apply the same random flips and transpose to every image."""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        """common"""
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img
    return [_augment(a) for a in args]


def search(root, target="JPEG"):
    """Recursively collect files whose name or parent directories match target."""
    item_list = []
    items = os.listdir(root)
    for item in items:
        path = os.path.join(root, item)
        if os.path.isdir(path):
            item_list.extend(search(path, target))
        elif path.split('/')[-1].startswith(target):
            item_list.append(path)
        elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
            item_list.append(path)
    return item_list


@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""div2k"""
import os
from src.data.srdata import SRData


class DIV2K(SRData):
    """DIV2K"""
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        # data_range looks like '1-800/801-900': train range / test range.
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]
        self.begin, self.end = list(map(int, data_range))
        super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
        self.dir_hr = None
        self.dir_lr = None

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        # Image numbers are 1-based, list indices are 0-based.
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')


@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""srdata"""
import os
import glob
import random
import pickle
import imageio
from src.data import common
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class SRData:
    """srdata"""
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0
        self._set_filesystem(args.dir_data)
        self._set_img(args)
        if train:
            self._repeat(args)

    def _set_img(self, args):
        """srdata"""
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)
        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or self.benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            # Cache every image as a pickled binary under <apath>/bin.
            os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
            for s in self.scale:
                if s == 1:
                    os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
                else:
                    os.makedirs(
                        os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

    def _repeat(self, args):
        """srdata"""
        n_patches = args.batch_size * args.test_every
        n_images = len(args.data_train) * len(self.images_hr)
        if n_images == 0:
            self.repeat = 0
        else:
            self.repeat = max(n_patches // n_images, 1)

    def _scan(self):
        """srdata"""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                if s != 1:
                    scale = s
                    names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
                        .format(s, filename, scale, self.ext[1])))
        for si, s in enumerate(self.scale):
            if s == 1:
                names_lr[si] = names_hr
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name[0])
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    # pylint: disable=unused-variable
    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
        #return pair_t[0], pair_t[1], [self.idx_scale], [filename]
        return pair_t[0], pair_t[1]

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        return idx

    def _load_file_deblur(self, idx, train=True):
        """srdata"""
        idx = self._get_index(idx)
        if train:
            f_hr = self.images_hr[idx]
            f_lr = self.images_lr[idx]
        else:
            f_hr = self.deblur_hr_test[idx]
            f_lr = self.deblur_lr_test[idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        filename = f_hr[-27:-17] + filename
        hr = imageio.imread(f_hr)
        lr = imageio.imread(f_lr)
        return lr, hr, filename

    def _load_file_hr(self, idx):
        """srdata"""
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
        return hr, filename

    def _load_rain_test(self, idx):
        f_hr = self.derain_hr_test[idx]
        f_lr = self.derain_lr_test[idx]
        filename, _ = os.path.splitext(os.path.basename(f_lr))
        norain = imageio.imread(f_hr)
        rain = imageio.imread(f_lr)
        return norain, rain, filename

    def _load_file(self, idx):
        """srdata"""
        idx = self._get_index(idx)
        # print(idx,flush=True)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)
        return lr, hr, filename

    def get_patch_hr(self, hr):
        """srdata"""
        if self.train:
            hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
        return hr

    def get_patch_img_hr(self, img, patch_size=96, scale=2):
        """srdata"""
        ih, iw = img.shape[:2]
        tp = patch_size
        ip = tp // scale
        ix = random.randrange(0, iw - ip + 1)
        iy = random.randrange(0, ih - ip + 1)
        ret = img[iy:iy + ip, ix:ix + ip, :]
        return ret

    def get_patch(self, lr, hr):
        """srdata"""
        scale = self.scale[self.idx_scale]
        if self.train:
            # patch_size is the LR patch side; the HR patch is scale times larger.
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size * scale,
                scale=scale)
            if not self.args.no_augment:
                lr, hr = common.augment(lr, hr)
        else:
            # Crop HR so that it is exactly scale times the LR size.
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]
        return lr, hr

    def set_scale(self, idx_scale):
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)


@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""
import math
import numpy as np
import cv2


def quantize(img, rgb_range):
    """quantize image range to 0-255"""
    pixel_range = 255 / rgb_range
    img = np.multiply(img, pixel_range)
    img = np.clip(img, 0, 255)
    img = np.round(img) / pixel_range
    return img


def calc_psnr(sr, hr, scale, rgb_range):
    """calculate psnr on the luma channel of NCHW inputs"""
    hr = np.float32(hr)
    sr = np.float32(sr)
    diff = (sr - hr) / rgb_range
    gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
    diff = np.multiply(diff, gray_coeffs).sum(1)
    if hr.size == 1:
        return 0
    if scale != 1:
        shave = scale
    else:
        shave = scale + 6
    if scale == 1:
        valid = diff
    else:
        # Ignore a border of `shave` pixels on every side.
        valid = diff[..., shave:-shave, shave:-shave]
    mse = np.mean(pow(valid, 2))
    return -10 * math.log10(mse)


def rgb2ycbcr(img, y_only=True):
    """from rgb space to ycbcr space (only the Y channel is implemented)"""
    # astype returns a new array, so the result has to be assigned.
    img = img.astype(np.float32)
    if not y_only:
        raise NotImplementedError('Only y_only=True is supported.')
    rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    return rlt


def calc_ssim(img1, img2, scale):
    """calculate ssim value"""
    def ssim(img1, img2):
        C1 = (0.01 * 255) ** 2
        C2 = (0.03 * 255) ** 2
        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        kernel = cv2.getGaussianKernel(11, 1.5)
        window = np.outer(kernel, kernel.transpose())
        mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
        mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
        sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
        sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()

    # Validate shapes before doing any work on the inputs.
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if scale != 1:
        border = scale
    else:
        border = scale + 6
    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]
    if img1_y.ndim == 2:
        return ssim(img1_y, img2_y)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            # Average the SSIM of the individual channels.
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[..., i], img2[..., i]))
            return np.array(ssims).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    raise ValueError('Wrong input image dimensions.')
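
As a rough, self-contained illustration of what `calc_psnr` above computes, the following NumPy sketch reproduces the luma-weighted PSNR with border shaving for `scale != 1`. The name `psnr_y` and the early return for identical images are assumptions made for this sketch; neither is part of the committed file.

```python
import math

import numpy as np


def psnr_y(sr, hr, scale, rgb_range=255.0):
    """Luma-weighted PSNR for NCHW arrays (hypothetical helper, not in the repo)."""
    diff = (np.float32(sr) - np.float32(hr)) / rgb_range
    # Same RGB-to-luma weights as calc_psnr, applied along the channel axis.
    coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
    diff = (diff * coeffs).sum(axis=1)
    if scale != 1:
        # Ignore a border of `scale` pixels on each side, as calc_psnr does.
        diff = diff[..., scale:-scale, scale:-scale]
    mse = float(np.mean(diff ** 2))
    if mse == 0:
        # Assumption: report infinity for identical images instead of failing in log10.
        return math.inf
    return -10 * math.log10(mse)
```

Because the difference is normalised by `rgb_range` before squaring, the peak value in the PSNR formula is 1, which is why the result reduces to `-10 * log10(mse)`.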

@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main model of wdsr"""
import mindspore
import mindspore.nn as nn
import numpy as np


class MeanShift(mindspore.nn.Conv2d):
    """add or sub means of input data"""
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, dtype=mindspore.float32):
        std = mindspore.Tensor(rgb_std, dtype)
        weight = mindspore.Tensor(np.eye(3), dtype).reshape(3, 3, 1, 1) / std.reshape(3, 1, 1, 1)
        bias = sign * rgb_range * mindspore.Tensor(rgb_mean, dtype) / std
        super(MeanShift, self).__init__(3, 3, kernel_size=1, has_bias=True, weight_init=weight, bias_init=bias)
        # The mean shift is a fixed transform, so its parameters are frozen.
        for p in self.get_parameters():
            p.requires_grad = False


class Block(nn.Cell):
    """residual block"""
    def __init__(self):
        super(Block, self).__init__()
        # wn = lambda x: mindspore.nn.GroupNorm(x)
        act = nn.ReLU()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.8
        # 1x1 expansion, ReLU, 1x1 linear low-rank projection, then 3x3 convolution.
        body.append(nn.Conv2d(64, 64 * expand, 1, padding=(1 // 2), pad_mode='pad', has_bias=True))
        body.append(act)
        body.append(nn.Conv2d(64 * expand, int(64 * linear), 1, padding=1 // 2, pad_mode='pad', has_bias=True))
        body.append(nn.Conv2d(int(64 * linear), 64, 3, padding=3 // 2, pad_mode='pad', has_bias=True))
        self.body = nn.SequentialCell(body)

    def construct(self, x):
        res = self.body(x)
        res += x
        return res


class PixelShuffle(nn.Cell):
    """perform pixel shuffle"""
    def __init__(self, upscale_factor):
        super().__init__()
        self.DepthToSpace = mindspore.ops.DepthToSpace(upscale_factor)

    def construct(self, x):
        return self.DepthToSpace(x)


class WDSR(nn.Cell):
    """main structure of wdsr"""
    def __init__(self):
        super(WDSR, self).__init__()
        scale = 2
        n_resblocks = 8
        n_feats = 64
        self.sub_mean = MeanShift(255)
        self.add_mean = MeanShift(255, sign=1)
        # define head module
        head = []
        head.append(
            nn.Conv2d(
                3, n_feats, 3,
                pad_mode='pad',
                padding=(3 // 2), has_bias=True))
        # define body module
        body = []
        for _ in range(n_resblocks):
            body.append(Block())
        # define tail module
        tail = []
        out_feats = scale * scale * 3
        tail.append(
            nn.Conv2d(n_feats, out_feats, 3, padding=3 // 2, pad_mode='pad', has_bias=True))
        self.depth_to_space = mindspore.ops.DepthToSpace(scale)
        # define skip module: a single 5x5 convolution from the input to the output features
        skip = []
        skip.append(
            nn.Conv2d(3, out_feats, 5, padding=5 // 2, pad_mode='pad', has_bias=True))
        self.head = nn.SequentialCell(head)
        self.body = nn.SequentialCell(body)
        self.tail = nn.SequentialCell(tail)
        self.skip = nn.SequentialCell(skip)

    def construct(self, x):
        x = self.sub_mean(x) / 127.5
        s = self.skip(x)
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        x += s
        x = self.depth_to_space(x)
        x = self.add_mean(x * 127.5)
        return x
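
The tail and skip convolutions above emit `scale * scale * 3` channels, which the depth-to-space step rearranges into a 3-channel image that is `scale` times larger in each spatial dimension. The NumPy sketch below illustrates that rearrangement. It uses the channel ordering of the usual pixel-shuffle definition, which is an assumption here; the ordering used by `mindspore.ops.DepthToSpace` may differ, and `pixel_shuffle` is a hypothetical name.

```python
import numpy as np


def pixel_shuffle(x, r):
    """Rearrange (N, C*r*r, H, W) into (N, C, H*r, W*r) (hypothetical helper)."""
    n, c_in, h, w = x.shape
    c_out = c_in // (r * r)
    # Split the channel axis into (c_out, r, r): output channel and sub-pixel offsets.
    x = x.reshape(n, c_out, r, r, h, w)
    # Interleave the sub-pixel offsets with the spatial axes.
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(n, c_out, h * r, w * r)


# 12 = 2 * 2 * 3 channels, matching out_feats for scale = 2.
features = np.zeros((1, 12, 48, 48), dtype=np.float32)
print(pixel_shuffle(features, 2).shape)
```

With `scale = 2`, a 48 x 48 feature map with 12 channels becomes a 96 x 96 map with 3 channels.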

@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wdsr train wrapper"""
import os
import time
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint


class Trainer():
    """Trainer"""
    def __init__(self, args, loader, my_model):
        self.args = args
        self.scale = args.scale
        self.trainloader = loader
        self.model = my_model
        self.model.set_train()
        self.criterion = nn.L1Loss()
        self.loss_history = []
        self.begin_time = time.time()
        self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
        self.loss_net = nn.WithLossCell(self.model, self.criterion)
        self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer)

    def train(self, epoch):
        """Train for one epoch; on rank 0, save a checkpoint every 10 epochs."""
        losses = 0
        batch_idx = 0
        for batch_idx, imgs in enumerate(self.trainloader):
            lr = imgs["LR"]
            hr = imgs["HR"]
            lr = Tensor(lr, mstype.float32)
            hr = Tensor(hr, mstype.float32)
            t1 = time.time()
            loss = self.net(lr, hr)
            t2 = time.time()
            losses += loss.asnumpy()
            print('Epoch: %g, Step: %g , loss: %f, time: %f s ' %
                  (epoch, batch_idx, loss.asnumpy(), t2 - t1), end='\n', flush=True)
        print("the epoch loss is", losses / (batch_idx + 1), flush=True)
        self.loss_history.append(losses / (batch_idx + 1))
        print(self.loss_history)
        t = time.time() - self.begin_time
        t = int(t)
        print("running time: %gh%g'%g''" % (t // 3600, (t - t // 3600 * 3600) // 60, t % 60), flush=True)
        os.makedirs(self.args.save, exist_ok=True)
        if self.args.rank == 0 and (epoch + 1) % 10 == 0:
            # Use the epoch passed in (self.epoch is only set by update_learning_rate)
            # and join the path so a missing trailing separator in args.save is harmless.
            save_checkpoint(self.net, os.path.join(self.args.save, "model_" + str(epoch) + '.ckpt'))

    def update_learning_rate(self, epoch):
        """Update the learning rate; called once per epoch.

        The learning rate is halved every 200 epochs, starting from args.lr.

        :param epoch: current epoch
        :type epoch: int
        """
        self.epoch = epoch
        print("*********** epoch: {} **********".format(epoch))
        lr = self.args.lr / (2 ** ((epoch + 1) // 200))
        self.adjust_lr('model', self.optimizer, lr)
        print("*********************************")

    def adjust_lr(self, name, optimizer, lr):
        """Adjust learning rate for the corresponding model.

        :param name: name of model
        :type name: str
        :param optimizer: the optimizer of the corresponding model
        :type optimizer: mindspore.nn.Optimizer
        :param lr: learning rate to be adjusted
        :type lr: float
        """
        lr_param = optimizer.get_lr()
        lr_param.assign_value(Tensor(lr, mstype.float32))
        print('==> ' + name + ' learning rate: ', lr_param.asnumpy())

@ -0,0 +1,76 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wdsr train script"""
import os
from mindspore import context
from mindspore import dataset as ds
import mindspore.nn as nn
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.args import args
from src.data.div2k import DIV2K
from src.model import WDSR


def train_net():
    """train wdsr"""
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    # if distribute:
    if device_num > 1:
        init()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          device_num=device_num, gradients_mean=True)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
    train_dataset.set_scale(args.task_id)
    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
                                           shard_id=rank_id, shuffle=True)
    train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
    net_m = WDSR()
    print("Init net successfully")
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(net_m, param_dict)
        print("Load net weight successfully")
    step_size = train_de_dataset.get_dataset_size()
    # Per-step learning rates: the rate is halved every 200 epochs.
    lr = []
    for i in range(0, args.epochs):
        cur_lr = args.lr / (2 ** ((i + 1) // 200))
        lr.extend([cur_lr] * step_size)
    opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=1024.0)
    loss = nn.L1Loss()
    loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale,
                                                 scale_factor=2, scale_window=1000)
    model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
                                 keep_checkpoint_max=args.ckpt_save_max)
    ckpt_cb = ModelCheckpoint(prefix="wdsr", directory=args.ckpt_save_path, config=config_ck)
    # Only the process on device 0 writes checkpoints.
    if device_id == 0:
        cb += [ckpt_cb]
    model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)


if __name__ == "__main__":
    train_net()
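
The per-step learning-rate list built in `train_net` above can be checked in isolation. The sketch below is plain Python with no MindSpore dependency; the function name `step_lr_schedule` and the `halve_every` parameter are hypothetical, introduced only to make the 200-epoch halving rule explicit.

```python
def step_lr_schedule(base_lr, epochs, steps_per_epoch, halve_every=200):
    """Per-step learning rates, halved every `halve_every` epochs (hypothetical helper)."""
    lrs = []
    for epoch in range(epochs):
        # (epoch + 1) // halve_every is the number of halvings applied so far.
        cur_lr = base_lr / (2 ** ((epoch + 1) // halve_every))
        lrs.extend([cur_lr] * steps_per_epoch)
    return lrs


schedule = step_lr_schedule(1e-4, epochs=400, steps_per_epoch=2)
print(len(schedule), schedule[0], schedule[-1])
```

Because the exponent uses `epoch + 1`, the first halving takes effect at zero-based epoch index 199 rather than 200.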