diff --git a/model_zoo/research/cv/wdsr/README_CN.md b/model_zoo/research/cv/wdsr/README_CN.md new file mode 100644 index 00000000000..5650c698156 --- /dev/null +++ b/model_zoo/research/cv/wdsr/README_CN.md @@ -0,0 +1,297 @@ +目录 + + + +- [目录](#目录) +- [WDSR描述](#WDSR描述) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [快速入门](#快速入门) +- [脚本说明](#脚本说明) + - [脚本及样例代码](#脚本及样例代码) + - [脚本参数](#脚本参数) + - [训练过程](#训练过程) + - [训练](#训练) + - [分布式训练](#分布式训练) + - [评估过程](#评估过程) + - [评估](#评估) +- [模型导出](#模型导出) +- [模型描述](#模型描述) + - [性能](#性能) + - [训练性能](#训练性能) + - [评估性能](#评估性能) +- [随机情况说明](#随机情况说明) +- [ModelZoo主页](#modelzoo主页) + + + +# WDSR描述 + +WDSR于2018年提出的 WDSR用于提高深度超分辨率网络的精度,它在 NTIRE 2018 单幅图像超分辨率挑战赛中获得了所有三个真实赛道的第一名。 +[论文1](https://arxiv.org/abs/1707.02921):Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** *2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**.[论文2](https://arxiv.org/abs/1808.08718): Jiahui Yu, Yuchen Fan, Jianchao Yang, Ning Xu, Zhaowen Wang, Xinchao Wang, Thomas Huang, **"Wide Activation for Efficient and Accurate Image Super-Resolution"**, arXiv preprint arXiv:1808.08718. + +# 模型架构 + +WDSR网络主要由几个基本模块(包括卷积层和池化层)组成。通过更广泛的激活和线性低秩卷积,并引入权重归一化实现更高精度的单幅图像超分辨率。这里的基本模块主要包括以下基本操作: **1 × 1 卷积**和**3 × 3 卷积**。 +经过1次卷积层,再串联32个残差模块,再经过1次卷积层,最后上采样并卷积。 + +# 数据集 + +使用的数据集:[DIV2K]() + +- 数据集大小:7.11G, + + - 训练集:共900张图像 + - 测试集:共100张图像 + +- 数据格式:png文件 + + - 注:数据将在src/data/DIV2K.py中处理。 + + ```shell + DIV2K + ├── DIV2K_test_LR_bicubic + │   ├── X2 + │   │   ├── 0901x2.png + │ │ ├─ ... + │   │   └── 1000x2.png + │   ├── X3 + │   │   ├── 0901x3.png + │ │ ├─ ... + │   │   └── 1000x3.png + │   └── X4 + │   ├── 0901x4.png + │ ├─ ... + │   └── 1000x4.png + ├── DIV2K_test_LR_unknown + │   ├── X2 + │   │   ├── 0901x2.png + │ │ ├─ ... + │   │   └── 1000x2.png + │   ├── X3 + │   │   ├── 0901x3.png + │ │ ├─ ... + │   │   └── 1000x3.png + │   └── X4 + │   ├── 0901x4.png + │ ├─ ... + │   └── 1000x4.png + ├── DIV2K_train_HR + │   ├── 0001.png + │ ├─ ... + │   └── 0900.png + ├── DIV2K_train_LR_bicubic + │   ├── X2 + │   │   ├── 0001x2.png + │ │ ├─ ... + │   │   └── 0900x2.png + │   ├── X3 + │   │   ├── 0001x3.png + │ │ ├─ ... + │   │   └── 0900x3.png + │   └── X4 + │   ├── 0001x4.png + │ ├─ ... + │   └── 0900x4.png + └── DIV2K_train_LR_unknown + ├── X2 + │   ├── 0001x2.png + │ ├─ ... + │   └── 0900x2.png + ├── X3 + │   ├── 0001x3.png + │ ├─ ... + │   └── 0900x3.png + └── X4 + ├── 0001x4.png + ├─ ... + └── 0900x4.png + ``` + +# 环境要求 + +- 硬件(Ascend) + - 使用ascend处理器来搭建硬件环境。 +- 框架 + - [MindSpore](https://www.mindspore.cn/install/en) +- 如需查看详情,请参见如下资源: + - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# 快速入门 + +通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估: + +```shell +#单卡训练 +sh run_ascend_standalone.sh [TRAIN_DATA_DIR] +``` + +```shell +#分布式训练 +sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR] +``` + +```python +#评估 +sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE] +``` + +# 脚本说明 + +## 脚本及样例代码 + +```bash +WDSR + ├── README_CN.md //自述文件 + ├── eval.py //评估脚本 + ├── export.py + ├── script + │   ├── run_ascend_distribute.sh //Ascend分布式训练shell脚本 + │   ├── run_ascend_standalone.sh //Ascend单卡训练shell脚本 + │   └── run_eval.sh //eval验证shell脚本 + ├── src + │   ├── args.py //超参数 + │   ├── common.py //公共网络模块 + │   ├── data + │   │   ├── common.py //公共数据集 + │   │   ├── div2k.py //div2k数据集 + │   │   └── srdata.py //所有数据集 + │   ├── metrics.py //PSNR和SSIM计算器 + │   ├── model.py //WDSR网络 + │   └── utils.py //训练脚本 + └── train.py //训练脚本 +``` + +## 脚本参数 + +主要参数如下: + +```python + -h, --help show this help message and exit + --dir_data DIR_DATA dataset directory + --data_train DATA_TRAIN + train dataset name + --data_test DATA_TEST + test dataset name + --data_range DATA_RANGE + train/test data range + --ext EXT dataset file extension + --scale SCALE super resolution scale + --patch_size PATCH_SIZE + output patch size + --rgb_range RGB_RANGE + maximum value of RGB + --n_colors N_COLORS number of color channels to use + --no_augment do not use data augmentation + --model MODEL model name + --n_resblocks N_RESBLOCKS + number of residual blocks + --n_feats N_FEATS number of feature maps + --res_scale RES_SCALE + residual scaling + --test_every TEST_EVERY + do test per every N batches + --epochs EPOCHS number of epochs to train + --batch_size BATCH_SIZE + input batch size for training + --test_only set this option to test the model + --lr LR learning rate + --ckpt_save_path CKPT_SAVE_PATH + path to save ckpt + --ckpt_save_interval CKPT_SAVE_INTERVAL + save ckpt frequency, unit is epoch + --ckpt_save_max CKPT_SAVE_MAX + max number of saved ckpt + --ckpt_path CKPT_PATH + path of saved ckpt + --task_id TASK_ID + +``` + +## 训练过程 + +### 训练 + +- Ascend处理器环境运行 + + ```bash + sh run_ascend_standalone.sh [TRAIN_DATA_DIR] + ``` + + 上述python命令将在后台运行,您可以通过train.log文件查看结果。 + +### 分布式训练 + +- Ascend处理器环境运行 + + ```bash + sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR] + ``` + +TRAIN_DATA_DIR = "~DATA/"。 + +## 评估过程 + +### 评估 + +在运行以下命令之前,请检查用于评估的检查点路径。 + +```bash +sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] DIV2K +``` + +TEST_DATA_DIR = "~DATA/"。 + +您可以通过eval.log文件查看结果。 + +# 模型导出 + +```bash +python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] +``` + +FILE_FORMAT 可选 ['MINDIR', 'AIR', 'ONNX'], 默认['MINDIR']。 + +# 模型描述 + +## 性能 + +### 训练性能 + +| 参数 | Ascend | +| ------------- | ------------------------------------------------------------ | +| 资源 | Ascend 910 | +| 上传日期 | 2021-7-4 | +| MindSpore版本 | 1.2.0 | +| 数据集 | DIV2K | +| 训练参数 | epoch=1000, steps=100, batch_size =16, lr=0.0001 | +| 优化器 | Adam | +| 损失函数 | L1 | +| 输出 | 超分辨率图片 | +| 损失 | 3.5 | +| 速度 | 8卡:约130毫秒/步 | +| 总时长 | 8卡:0.5小时 | +| 微调检查点 | 35 MB(.ckpt文件) | +| 脚本 | [WDSR](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/wdsr) | + +### 评估性能 + +| 参数 | Ascend | +| ------------- | ----------------------------------------------------------- | +| 资源 | Ascend 910 | +| 上传日期 | 2021-7-4 | +| MindSpore版本 | 1.2.0 | +| 数据集 | DIV2K | +| batch_size | 1 | +| 输出 | 超分辨率图片 | +| PSNR | DIV2K 34.7780 | + +# 随机情况说明 + +在train.py中,我们设置了“train_net”函数内的种子。 + +# ModelZoo主页 + + 请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。 diff --git a/model_zoo/research/cv/wdsr/eval.py b/model_zoo/research/cv/wdsr/eval.py new file mode 100644 index 00000000000..97a785fc6c0 --- /dev/null +++ b/model_zoo/research/cv/wdsr/eval.py @@ -0,0 +1,74 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""wdsr eval script""" +import os +import numpy as np +import mindspore.dataset as ds +from mindspore import Tensor, context +from mindspore.common import dtype as mstype +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.args import args +import src.model as wdsr +from src.data.srdata import SRData +from src.data.div2k import DIV2K +from src.metrics import calc_psnr, quantize, calc_ssim +device_id = int(os.getenv('DEVICE_ID', '0')) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) +context.set_context(max_call_depth=10000) +def eval_net(): + """eval""" + if args.epochs == 0: + args.epochs = 1e8 + for arg in vars(args): + if vars(args)[arg] == 'True': + vars(args)[arg] = True + elif vars(args)[arg] == 'False': + vars(args)[arg] = False + if args.data_test[0] == 'DIV2K': + train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False) + else: + train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) + train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False) + train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) + train_loader = train_de_dataset.create_dict_iterator(output_numpy=True) + net_m = wdsr.WDSR() + if args.ckpt_path: + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(net_m, param_dict) + net_m.set_train(False) + print('load mindspore net successfully.') + num_imgs = train_de_dataset.get_dataset_size() + psnrs = np.zeros((num_imgs, 1)) + ssims = np.zeros((num_imgs, 1)) + for batch_idx, imgs in enumerate(train_loader): + lr = imgs['LR'] + hr = imgs['HR'] + lr = Tensor(lr, mstype.float32) + pred = net_m(lr) + pred_np = pred.asnumpy() + pred_np = quantize(pred_np, 255) + psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0) + pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0) + hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0) + ssim = calc_ssim(pred_np, hr, args.scale[0]) + print("current psnr: ", psnr) + print("current ssim: ", ssim) + psnrs[batch_idx, 0] = psnr + ssims[batch_idx, 0] = ssim + print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) + print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0])) +if __name__ == '__main__': + print("Start eval function!") + eval_net() diff --git a/model_zoo/research/cv/wdsr/export.py b/model_zoo/research/cv/wdsr/export.py new file mode 100644 index 00000000000..6678ae07341 --- /dev/null +++ b/model_zoo/research/cv/wdsr/export.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""export net together with checkpoint into air/mindir/onnx models""" +import os +import argparse +import numpy as np +from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export +import src.model as wdsr + +parser = argparse.ArgumentParser(description='wdsr export') +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file") +parser.add_argument("--file_name", type=str, default="wdsr", help="output file name.") +parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format") +args = parser.parse_args() + +def run_export(args1): + """run_export""" + device_id = int(os.getenv("DEVICE_ID", '0')) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) + net = wdsr.WDSR() + param_dict = load_checkpoint(args1.ckpt_path) + load_param_into_net(net, param_dict) + net.set_train(False) + print('load mindspore net and checkpoint successfully.') + inputs = Tensor(np.zeros([args1.batch_size, 3, 678, 1020], np.float32)) + export(net, inputs, file_name=args.file_name, file_format=args.file_format) + print('export successfully!') + +if __name__ == "__main__": + run_export(args) diff --git a/model_zoo/research/cv/wdsr/script/run_ascend_distribute.sh b/model_zoo/research/cv/wdsr/script/run_ascend_distribute.sh new file mode 100644 index 00000000000..de20743d5b8 --- /dev/null +++ b/model_zoo/research/cv/wdsr/script/run_ascend_distribute.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ ! -f $PATH1 ]; then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" + exit 1 +fi + +if [ ! -d $PATH2 ]; then + echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory" + exit 1 +fi + +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +for ((i = 0; i < ${DEVICE_NUM}; i++)); do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env >env.log + + nohup python train.py \ + --batch_size 16 \ + --lr 1e-4 \ + --scale 2 \ + --task_id 0 \ + --dir_data $PATH2 \ + --epochs 1000 \ + --test_every 100 \ + --patch_size 48 > train.log 2>&1 & + cd .. +done diff --git a/model_zoo/research/cv/wdsr/script/run_ascend_standalone.sh b/model_zoo/research/cv/wdsr/script/run_ascend_standalone.sh new file mode 100644 index 00000000000..eb62b5d2268 --- /dev/null +++ b/model_zoo/research/cv/wdsr/script/run_ascend_standalone.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +#!/bin/bash +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 1 ]; then + echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) + +if [ ! -d $PATH1 ]; then + echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory" + exit 1 +fi + + +if [ -d "train" ]; then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp -r ../src ./train +cd ./train || exit + +env >env.log + +nohup python train.py \ + --batch_size 16 \ + --lr 1e-4 \ + --scale 2 \ + --task_id 0 \ + --dir_data $PATH1 \ + --epochs 1000 \ + --test_every 100 \ + --patch_size 48 > train.log 2>&1 & diff --git a/model_zoo/research/cv/wdsr/script/run_eval.sh b/model_zoo/research/cv/wdsr/script/run_eval.sh new file mode 100644 index 00000000000..f8828810f9e --- /dev/null +++ b/model_zoo/research/cv/wdsr/script/run_eval.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ]; then + echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) +DATASET_TYPE=$3 + +if [ ! -d $PATH1 ]; then + echo "error: TEST_DATA_DIR=$PATH1 is not a directory" + exit 1 +fi + +if [ ! -f $PATH2 ]; then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" + exit 1 +fi + +if [ -d "eval" ]; then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp -r ../src ./eval +cd ./eval || exit +env >env.log +echo "start evaluation ..." + +python eval.py \ + --dir_data=${PATH1} \ + --test_only \ + --ext "img" \ + --data_test=${DATASET_TYPE} \ + --ckpt_path=${PATH2} \ + --task_id 0 \ + --scale 2 > eval.log 2>&1 & diff --git a/model_zoo/research/cv/wdsr/src/args.py b/model_zoo/research/cv/wdsr/src/args.py new file mode 100644 index 00000000000..230a6d6160b --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/args.py @@ -0,0 +1,102 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""args parser""" +import argparse +parser = argparse.ArgumentParser(description='EDSR and MDSR') +# Data specifications +parser.add_argument('--dir_data', type=str, default='/cache/data/', + help='dataset directory') +parser.add_argument('--data_train', type=str, default='DIV2K', + help='train dataset name') +parser.add_argument('--data_test', type=str, default='DIV2K', + help='test dataset name') +parser.add_argument('--data_range', type=str, default='1-800/801-900', + help='train/test data range') +parser.add_argument('--ext', type=str, default='sep', + help='dataset file extension') +parser.add_argument('--scale', type=str, default='4', + help='super resolution scale') +parser.add_argument('--patch_size', type=int, default=48, + help='output patch size') +parser.add_argument('--rgb_range', type=int, default=255, + help='maximum value of RGB') +parser.add_argument('--n_colors', type=int, default=3, + help='number of color channels to use') +parser.add_argument('--no_augment', action='store_true', + help='do not use data augmentation') +# Model specifications +parser.add_argument('--model', default='WDSR', + help='model name') +parser.add_argument('--n_resblocks', type=int, default=16, + help='number of residual blocks') +parser.add_argument('--n_feats', type=int, default=64, + help='number of feature maps') +parser.add_argument('--res_scale', type=float, default=1, + help='residual scaling') +# Training specifications +parser.add_argument('--test_every', type=int, default=1000, + help='do test per every N batches') +parser.add_argument('--epochs', type=int, default=300, + help='number of epochs to train') +parser.add_argument('--batch_size', type=int, default=16, + help='input batch size for training') +#parser.add_argument('--self_ensemble', action='store_true', +# help='use self-ensemble method for test') +parser.add_argument('--test_only', action='store_true', + help='set this option to test the model') +# Optimization specifications +parser.add_argument('--lr', type=float, default=1e-4, + help='learning rate') +parser.add_argument('--init_loss_scale', type=float, default=65536., + help='scaling factor') +parser.add_argument('--decay', type=str, default='200', + help='learning rate decay type') +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), + help='ADAM beta') +parser.add_argument('--epsilon', type=float, default=1e-8, + help='ADAM epsilon for numerical stability') +parser.add_argument('--weight_decay', type=float, default=0, + help='weight decay') +parser.add_argument('--gclip', type=float, default=0, + help='gradient clipping threshold (0 = no clipping)') +# ckpt specifications +parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/', + help='path to save ckpt') +parser.add_argument('--ckpt_save_interval', type=int, default=10, + help='save ckpt frequency, unit is epoch') +parser.add_argument('--ckpt_save_max', type=int, default=5, + help='max number of saved ckpt') +parser.add_argument('--ckpt_path', type=str, default='', + help='path of saved ckpt') +# alltask +parser.add_argument('--task_id', type=int, default=0) +# rgb_mean +parser.add_argument('--r_mean', type=float, default=0.4488, + help='r_mean') +parser.add_argument('--g_mean', type=float, default=0.4371, + help='g_mean') +parser.add_argument('--b_mean', type=float, default=0.4040, + help='b_mean') +args, unparsed = parser.parse_known_args() +args.scale = [int(x) for x in args.scale.split("+")] +args.data_train = args.data_train.split('+') +args.data_test = args.data_test.split('+') +if args.epochs == 0: + args.epochs = 1e4 +for arg in vars(args): + if vars(args)[arg] == 'True': + vars(args)[arg] = True + elif vars(args)[arg] == 'False': + vars(args)[arg] = False diff --git a/model_zoo/research/cv/wdsr/src/data/common.py b/model_zoo/research/cv/wdsr/src/data/common.py new file mode 100644 index 00000000000..370fc37fa59 --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/data/common.py @@ -0,0 +1,78 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""common""" +import random +import numpy as np +def get_patch(*args, patch_size=96, scale=2, input_large=False): + """common""" + ih, iw = args[0].shape[:2] + tp = patch_size + ip = tp // scale + # print("tp=%g,scale=%g,"%(tp,scale),end='',flush=True) + # print("ih=%g,iw=%g,ip=%g"%(ih,iw,ip),flush=True) + ix = random.randrange(0, iw - ip + 1) + iy = random.randrange(0, ih - ip + 1) + # ix = iy = 0 + if not input_large: + tx, ty = scale * ix, scale * iy + else: + tx, ty = ix, iy + ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]] + return ret +def set_channel(*args, n_channels=3): + """common""" + def _set_channel(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + c = img.shape[2] + if n_channels == 3 and c == 1: + img = np.concatenate([img] * n_channels, 2) + return img[:, :, :n_channels] + return [_set_channel(a) for a in args] +def np2Tensor(*args, rgb_range=255): + def _np2Tensor(img): + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) + input_data = np_transpose.astype(np.float32) + output = input_data * (rgb_range / 255) + return output + return [_np2Tensor(a) for a in args] +def augment(*args, hflip=True, rot=True): + """common""" + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + def _augment(img): + """common""" + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + return [_augment(a) for a in args] +def search(root, target="JPEG"): + """srdata""" + item_list = [] + items = os.listdir(root) + for item in items: + path = os.path.join(root, item) + if os.path.isdir(path): + item_list.extend(search(path, target)) + elif path.split('/')[-1].startswith(target): + item_list.append(path) + elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]): + item_list.append(path) + return item_list diff --git a/model_zoo/research/cv/wdsr/src/data/div2k.py b/model_zoo/research/cv/wdsr/src/data/div2k.py new file mode 100644 index 00000000000..7a9bfde8b8d --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/data/div2k.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""div2k""" +import os +from src.data.srdata import SRData +class DIV2K(SRData): + """DIV2K""" + def __init__(self, args, name='DIV2K', train=True, benchmark=False): + data_range = [r.split('-') for r in args.data_range.split('/')] + if train: + data_range = data_range[0] + else: + if args.test_only and len(data_range) == 1: + data_range = data_range[0] + else: + data_range = data_range[1] + self.begin, self.end = list(map(int, data_range)) + super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark) + self.dir_hr = None + self.dir_lr = None + + def _scan(self): + names_hr, names_lr = super(DIV2K, self)._scan() + names_hr = names_hr[self.begin - 1:self.end] + names_lr = [n[self.begin - 1:self.end] for n in names_lr] + return names_hr, names_lr + + def _set_filesystem(self, dir_data): + super(DIV2K, self)._set_filesystem(dir_data) + self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') + self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') diff --git a/model_zoo/research/cv/wdsr/src/data/srdata.py b/model_zoo/research/cv/wdsr/src/data/srdata.py new file mode 100644 index 00000000000..e9bba1b01ae --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/data/srdata.py @@ -0,0 +1,215 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""""srdata""" +import os +import glob +import random +import pickle +import imageio +from src.data import common +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True +class SRData: + """srdata""" + def __init__(self, args, name='', train=True, benchmark=False): + self.args = args + self.name = name + self.train = train + self.split = 'train' if train else 'test' + self.do_eval = True + self.benchmark = benchmark + self.input_large = (args.model == 'VDSR') + self.scale = args.scale + self.idx_scale = 0 + self._set_filesystem(args.dir_data) + self._set_img(args) + if train: + self._repeat(args) + + def _set_img(self, args): + """srdata""" + if args.ext.find('img') < 0: + path_bin = os.path.join(self.apath, 'bin') + os.makedirs(path_bin, exist_ok=True) + list_hr, list_lr = self._scan() + if args.ext.find('img') >= 0 or self.benchmark: + self.images_hr, self.images_lr = list_hr, list_lr + elif args.ext.find('sep') >= 0: + os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True) + for s in self.scale: + if s == 1: + os.makedirs(os.path.join(self.dir_hr), exist_ok=True) + else: + os.makedirs( + os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True) + self.images_hr, self.images_lr = [], [[] for _ in self.scale] + for h in list_hr: + b = h.replace(self.apath, path_bin) + b = b.replace(self.ext[0], '.pt') + self.images_hr.append(b) + self._check_and_load(args.ext, h, b, verbose=True) + for i, ll in enumerate(list_lr): + for l in ll: + b = l.replace(self.apath, path_bin) + b = b.replace(self.ext[1], '.pt') + self.images_lr[i].append(b) + self._check_and_load(args.ext, l, b, verbose=True) + + def _repeat(self, args): + """srdata""" + n_patches = args.batch_size * args.test_every + n_images = len(args.data_train) * len(self.images_hr) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = max(n_patches // n_images, 1) + + def _scan(self): + """srdata""" + names_hr = sorted( + glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))) + names_lr = [[] for _ in self.scale] + for f in names_hr: + filename, _ = os.path.splitext(os.path.basename(f)) + for si, s in enumerate(self.scale): + if s != 1: + scale = s + names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \ + .format(s, filename, scale, self.ext[1]))) + for si, s in enumerate(self.scale): + if s == 1: + names_lr[si] = names_hr + return names_hr, names_lr + + def _set_filesystem(self, dir_data): + self.apath = os.path.join(dir_data, self.name[0]) + self.dir_hr = os.path.join(self.apath, 'HR') + self.dir_lr = os.path.join(self.apath, 'LR_bicubic') + self.ext = ('.png', '.png') + + def _check_and_load(self, ext, img, f, verbose=True): + if not os.path.isfile(f) or ext.find('reset') >= 0: + if verbose: + print('Making a binary: {}'.format(f)) + with open(f, 'wb') as _f: + pickle.dump(imageio.imread(img), _f) + + # pylint: disable=unused-variable + def __getitem__(self, idx): + lr, hr, filename = self._load_file(idx) + pair = self.get_patch(lr, hr) + pair = common.set_channel(*pair, n_channels=self.args.n_colors) + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) + #return pair_t[0], pair_t[1], [self.idx_scale], [filename] + return pair_t[0], pair_t[1] + + def __len__(self): + if self.train: + return len(self.images_hr) * self.repeat + return len(self.images_hr) + + def _get_index(self, idx): + if self.train: + return idx % len(self.images_hr) + return idx + + def _load_file_deblur(self, idx, train=True): + """srdata""" + idx = self._get_index(idx) + if train: + f_hr = self.images_hr[idx] + f_lr = self.images_lr[idx] + else: + f_hr = self.deblur_hr_test[idx] + f_lr = self.deblur_lr_test[idx] + filename, _ = os.path.splitext(os.path.basename(f_hr)) + filename = f_hr[-27:-17] + filename + hr = imageio.imread(f_hr) + lr = imageio.imread(f_lr) + return lr, hr, filename + + def _load_file_hr(self, idx): + """srdata""" + idx = self._get_index(idx) + f_hr = self.images_hr[idx] + filename, _ = os.path.splitext(os.path.basename(f_hr)) + if self.args.ext == 'img' or self.benchmark: + hr = imageio.imread(f_hr) + elif self.args.ext.find('sep') >= 0: + with open(f_hr, 'rb') as _f: + hr = pickle.load(_f) + return hr, filename + + def _load_rain_test(self, idx): + f_hr = self.derain_hr_test[idx] + f_lr = self.derain_lr_test[idx] + filename, _ = os.path.splitext(os.path.basename(f_lr)) + norain = imageio.imread(f_hr) + rain = imageio.imread(f_lr) + return norain, rain, filename + + def _load_file(self, idx): + """srdata""" + idx = self._get_index(idx) + # print(idx,flush=True) + f_hr = self.images_hr[idx] + f_lr = self.images_lr[self.idx_scale][idx] + filename, _ = os.path.splitext(os.path.basename(f_hr)) + if self.args.ext == 'img' or self.benchmark: + hr = imageio.imread(f_hr) + lr = imageio.imread(f_lr) + elif self.args.ext.find('sep') >= 0: + with open(f_hr, 'rb') as _f: + hr = pickle.load(_f) + with open(f_lr, 'rb') as _f: + lr = pickle.load(_f) + return lr, hr, filename + + def get_patch_hr(self, hr): + """srdata""" + if self.train: + hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1) + return hr + + def get_patch_img_hr(self, img, patch_size=96, scale=2): + """srdata""" + ih, iw = img.shape[:2] + tp = patch_size + ip = tp // scale + ix = random.randrange(0, iw - ip + 1) + iy = random.randrange(0, ih - ip + 1) + ret = img[iy:iy + ip, ix:ix + ip, :] + return ret + + def get_patch(self, lr, hr): + """srdata""" + scale = self.scale[self.idx_scale] + if self.train: + lr, hr = common.get_patch( + lr, hr, + patch_size=self.args.patch_size * scale, + scale=scale) + if not self.args.no_augment: + lr, hr = common.augment(lr, hr) + else: + ih, iw = lr.shape[:2] + hr = hr[0:ih * scale, 0:iw * scale] + return lr, hr + + def set_scale(self, idx_scale): + if not self.input_large: + self.idx_scale = idx_scale + else: + self.idx_scale = random.randint(0, len(self.scale) - 1) diff --git a/model_zoo/research/cv/wdsr/src/metrics.py b/model_zoo/research/cv/wdsr/src/metrics.py new file mode 100644 index 00000000000..5fc097fcc89 --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/metrics.py @@ -0,0 +1,102 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""metrics""" +import math +import numpy as np +import cv2 + + +def quantize(img, rgb_range): + """quantize image range to 0-255""" + pixel_range = 255 / rgb_range + img = np.multiply(img, pixel_range) + img = np.clip(img, 0, 255) + img = np.round(img) / pixel_range + return img + + +def calc_psnr(sr, hr, scale, rgb_range): + """calculate psnr""" + hr = np.float32(hr) + sr = np.float32(sr) + diff = (sr - hr) / rgb_range + gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256 + diff = np.multiply(diff, gray_coeffs).sum(1) + if hr.size == 1: + return 0 + if scale != 1: + shave = scale + else: + shave = scale + 6 + if scale == 1: + valid = diff + else: + valid = diff[..., shave:-shave, shave:-shave] + mse = np.mean(pow(valid, 2)) + return -10 * math.log10(mse) + + +def rgb2ycbcr(img, y_only=True): + """from rgb space to ycbcr space""" + img.astype(np.float32) + if y_only: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + return rlt + + +def calc_ssim(img1, img2, scale): + """calculate ssim value""" + def ssim(img1, img2): + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + border = 0 + if scale != 1: + border = scale + else: + border = scale + 6 + img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0 + img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0 + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1_y = img1_y[border:h - border, border:w - border] + img2_y = img2_y[border:h - border, border:w - border] + if img1_y.ndim == 2: + return ssim(img1_y, img2_y) + if img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for _ in range(3): + ssims.append(ssim(img1, img2)) + return np.array(ssims).mean() + if img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') diff --git a/model_zoo/research/cv/wdsr/src/model.py b/model_zoo/research/cv/wdsr/src/model.py new file mode 100644 index 00000000000..b9e1e50a9a7 --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/model.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""main model of wdsr""" + +import mindspore +import mindspore.nn as nn +import numpy as np + + +class MeanShift(mindspore.nn.Conv2d): + """add or sub means of input data""" + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, dtype=mindspore.float32): + std = mindspore.Tensor(rgb_std, dtype) + weight = mindspore.Tensor(np.eye(3), dtype).reshape(3, 3, 1, 1) / std.reshape(3, 1, 1, 1) + bias = sign * rgb_range * mindspore.Tensor(rgb_mean, dtype) / std + super(MeanShift, self).__init__(3, 3, kernel_size=1, has_bias=True, weight_init=weight, bias_init=bias) + for p in self.get_parameters(): + p.requires_grad = False + + +class Block(nn.Cell): + """residual block""" + def __init__(self): + super(Block, self).__init__() + # wn = lambda x: mindspore.nn.GroupNorm(x) + act = nn.ReLU() + self.res_scale = 1 + body = [] + expand = 6 + linear = 0.8 + body.append(nn.Conv2d(64, 64 * expand, 1, padding=(1 // 2), pad_mode='pad', has_bias=True)) + body.append(act) + body.append(nn.Conv2d(64 * expand, int(64 * linear), 1, padding=1 // 2, pad_mode='pad', has_bias=True)) + body.append(nn.Conv2d(int(64 * linear), 64, 3, padding=3 // 2, pad_mode='pad', has_bias=True)) + self.body = nn.SequentialCell(body) + + def construct(self, x): + res = self.body(x) + res += x + return res + + +class PixelShuffle(nn.Cell): + """perform pixel shuffle""" + def __init__(self, upscale_factor): + super().__init__() + self.DepthToSpace = mindspore.ops.DepthToSpace(upscale_factor) + + def construct(self, x): + return self.DepthToSpace(x) + + +class WDSR(nn.Cell): + """main structure of wdsr""" + def __init__(self): + super(WDSR, self).__init__() + scale = 2 + n_resblocks = 8 + n_feats = 64 + self.sub_mean = MeanShift(255) + self.add_mean = MeanShift(255, sign=1) + # define head module + head = [] + head.append( + nn.Conv2d( + 3, n_feats, 3, + pad_mode='pad', + padding=(3 // 2), has_bias=True)) + + # define body module + body = [] + for _ in range(n_resblocks): + body.append(Block()) + # define tail module + tail = [] + out_feats = scale * scale * 3 + tail.append( + nn.Conv2d(n_feats, out_feats, 3, padding=3 // 2, pad_mode='pad', has_bias=True)) + self.depth_to_space = mindspore.ops.DepthToSpace(2) + skip = [] + skip.append( + nn.Conv2d(3, out_feats, 5, padding=5 // 2, pad_mode='pad', has_bias=True)) + self.head = nn.SequentialCell(head) + self.body = nn.SequentialCell(body) + self.tail = nn.SequentialCell(tail) + self.skip = nn.SequentialCell(skip) + + def construct(self, x): + x = self.sub_mean(x) / 127.5 + s = self.skip(x) + x = self.head(x) + x = self.body(x) + x = self.tail(x) + x += s + x = self.depth_to_space(x) + x = self.add_mean(x * 127.5) + return x diff --git a/model_zoo/research/cv/wdsr/src/utils.py b/model_zoo/research/cv/wdsr/src/utils.py new file mode 100644 index 00000000000..0b9fe64034c --- /dev/null +++ b/model_zoo/research/cv/wdsr/src/utils.py @@ -0,0 +1,87 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""wdsr train wrapper""" +import os +import time +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.train.serialization import save_checkpoint +class Trainer(): + """Trainer""" + def __init__(self, args, loader, my_model): + self.args = args + self.scale = args.scale + self.trainloader = loader + self.model = my_model + self.model.set_train() + self.criterion = nn.L1Loss() + self.loss_history = [] + self.begin_time = time.time() + self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0) + self.loss_net = nn.WithLossCell(self.model, self.criterion) + self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer) + def train(self, epoch): + """Trainer""" + losses = 0 + batch_idx = 0 + for batch_idx, imgs in enumerate(self.trainloader): + lr = imgs["LR"] + hr = imgs["HR"] + lr = Tensor(lr, mstype.float32) + hr = Tensor(hr, mstype.float32) + t1 = time.time() + loss = self.net(lr, hr) + t2 = time.time() + losses += loss.asnumpy() + print('Epoch: %g, Step: %g , loss: %f, time: %f s ' % \ + (epoch, batch_idx, loss.asnumpy(), t2 - t1), end='\n', flush=True) + print("the epoch loss is", losses / (batch_idx + 1), flush=True) + self.loss_history.append(losses / (batch_idx + 1)) + print(self.loss_history) + t = time.time() - self.begin_time + t = int(t) + print(", running time: %gh%g'%g''"%(t//3600, (t-t//3600*3600)//60, t%60), flush=True) + os.makedirs(self.args.save, exist_ok=True) + if self.args.rank == 0 and (epoch+1)%10 == 0: + save_checkpoint(self.net, self.args.save + "model_" + str(self.epoch) + '.ckpt') + def update_learning_rate(self, epoch): + """Update learning rates for all the networks; called at the end of every epoch. + :param epoch: current epoch + :type epoch: int + :param lr: learning rate of cyclegan + :type lr: float + :param niter: number of epochs with the initial learning rate + :type niter: int + :param niter_decay: number of epochs to linearly decay learning rate to zero + :type niter_decay: int + """ + self.epoch = epoch + print("*********** epoch: {} **********".format(epoch)) + lr = self.args.lr / (2 ** ((epoch+1)//200)) + self.adjust_lr('model', self.optimizer, lr) + print("*********************************") + def adjust_lr(self, name, optimizer, lr): + """Adjust learning rate for the corresponding model. + :param name: name of model + :type name: str + :param optimizer: the optimizer of the corresponding model + :type optimizer: torch.optim + :param lr: learning rate to be adjusted + :type lr: float + """ + lr_param = optimizer.get_lr() + lr_param.assign_value(Tensor(lr, mstype.float32)) + print('==> ' + name + ' learning rate: ', lr_param.asnumpy()) diff --git a/model_zoo/research/cv/wdsr/train.py b/model_zoo/research/cv/wdsr/train.py new file mode 100644 index 00000000000..6c8c8bfafbc --- /dev/null +++ b/model_zoo/research/cv/wdsr/train.py @@ -0,0 +1,76 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""wdsr train script""" +import os +from mindspore import context +from mindspore import dataset as ds +import mindspore.nn as nn +from mindspore.context import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.communication.management import init +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.common import set_seed +from mindspore.train.model import Model +from mindspore.train.loss_scale_manager import DynamicLossScaleManager +from src.args import args +from src.data.div2k import DIV2K +from src.model import WDSR + +def train_net(): + """train wdsr""" + set_seed(1) + device_id = int(os.getenv('DEVICE_ID', '0')) + rank_id = int(os.getenv('RANK_ID', '0')) + device_num = int(os.getenv('RANK_SIZE', '1')) + # if distribute: + if device_num > 1: + init() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, + device_num=device_num, gradients_mean=True) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) + train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False) + train_dataset.set_scale(args.task_id) + train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num, + shard_id=rank_id, shuffle=True) + train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True) + net_m = WDSR() + print("Init net successfully") + if args.ckpt_path: + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(net_m, param_dict) + print("Load net weight successfully") + step_size = train_de_dataset.get_dataset_size() + lr = [] + for i in range(0, args.epochs): + cur_lr = args.lr / (2 ** ((i + 1)//200)) + lr.extend([cur_lr] * step_size) + opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=1024.0) + loss = nn.L1Loss() + loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \ + scale_factor=2, scale_window=1000) + model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager) + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossMonitor() + cb = [time_cb, loss_cb] + config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size, + keep_checkpoint_max=args.ckpt_save_max) + ckpt_cb = ModelCheckpoint(prefix="wdsr", directory=args.ckpt_save_path, config=config_ck) + if device_id == 0: + cb += [ckpt_cb] + model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True) + + +if __name__ == "__main__": + train_net()