!22393 add a network EDSR

Merge pull request !22393 from 李元龙/master
This commit is contained in:
i-robot 2021-09-01 14:17:52 +00:00 committed by Gitee
commit 5a851daf2f
37 changed files with 3128 additions and 1295 deletions

View File

@ -0,0 +1,126 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
device_target: "Ascend"
# ==============================================================================
# train options
amp_level: "O3"
loss_scale: 1000.0 # for ['O2', 'O3', 'auto']
keep_checkpoint_max: 60
save_epoch_frq: 100
ckpt_save_dir: "./ckpt/"
epoch_size: 6000
# eval options
eval_epoch_frq: 20
self_ensemble: True
save_sr: True
# Adam opt options
opt_type: Adam
weight_decay: 0.0
# learning rate options
learning_rate: 0.0001
milestones: [4000]
gamma: 0.5
# dataset options
dataset_name: "DIV2K"
lr_type: "bicubic"
batch_size: 2
patch_size: 192
scale: 4
dataset_sink_mode: True
need_unzip_in_modelarts: False
need_unzip_files:
- "DIV2K_train_HR.zip"
- "DIV2K_train_LR_bicubic_X2.zip"
- "DIV2K_train_LR_bicubic_X3.zip"
- "DIV2K_train_LR_bicubic_X4.zip"
- "DIV2K_train_LR_unknown_X2.zip"
- "DIV2K_train_LR_unknown_X3.zip"
- "DIV2K_train_LR_unknown_X4.zip"
- "DIV2K_valid_HR.zip"
- "DIV2K_valid_LR_bicubic_X2.zip"
- "DIV2K_valid_LR_bicubic_X3.zip"
- "DIV2K_valid_LR_bicubic_X4.zip"
- "DIV2K_valid_LR_unknown_X2.zip"
- "DIV2K_valid_LR_unknown_X3.zip"
- "DIV2K_valid_LR_unknown_X4.zip"
# net options
pre_trained: ""
rgb_range: 255
rgb_mean: [0.4488, 0.4371, 0.4040]
rgb_std: [1.0, 1.0, 1.0]
n_colors: 3
n_feats: 256
kernel_size: 3
n_resblocks: 32
res_scale: 0.1
---
# helper
enable_modelarts: "set True if run in modelarts, default: False"
# Url for modelarts
data_url: "modelarts data path"
train_url: "modelarts code path"
checkpoint_url: "modelarts checkpoint save path"
# Path for local
data_path: "local data path, data will be download from 'data_url', default: /cache/data"
output_path: "local output path, checkpoint will be upload to 'checkpoint_url', default: /cache/train"
device_target: "choice from ['Ascend'], default: Ascend"
# ==============================================================================
# train options
amp_level: "choice from ['O0', 'O2', 'O3', 'auto'], default: O3"
loss_scale: "loss scale will be used except 'O0', default: 1000.0"
keep_checkpoint_max: "max number of checkpoints to be saved, defalue: 60"
save_epoch_frq: "frequency to save checkpoint, defalue: 100"
ckpt_save_dir: "the relative path to save checkpoint, root path is 'output_path', defalue: ./ckpt/"
epoch_size: "the number of training epochs, defalue: 6000"
# eval options
eval_epoch_frq: "frequency to evaluate model, defalue: 20"
self_ensemble: "set True if wanna do self-ensemble while evaluating, defalue: True"
save_sr: "set True if wanna save sr and hr image while evaluating, defalue: True"
# opt options
opt_type: "optimizer type, choice from ['Adam'], defalue: Adam"
weight_decay: "weight_decay for optimizer, defalue: 0.0"
# learning rate options
learning_rate: "learning rate, defalue: 0.0001"
milestones: "the key epoch to do a gamma decay, defalue: [4000]"
gamma: "gamma decay rate, defalue: 0.5"
# dataset options
dataset_name: "dataset name, defalue: DIV2K"
lr_type: "lr image degeneration type, choice from ['bicubic', 'unknown'], defalue: bicubic"
batch_size: "batch size for training; total batch size = 16 is recommended, defalue: 2"
patch_size: "cut hr images into patch size for training, lr images auto-adjust by 'scale', defalue: 192"
scale: "scale for super resolution reconstruction, choice from [2,3,4], defalue: 4"
dataset_sink_mode: "set True if wanna using dataset sink mode, defalue: True"
need_unzip_in_modelarts: "set True if wanna unzip data after download data from s3, defalue: False"
need_unzip_files: "list of zip files to unzip, only work while 'need_unzip_in_modelarts'=True"
# net options
pre_trained: "load pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3_ABS_PATH], [RELATIVE_PATH below 'output_path'], [LOCAL_ABS_PATH], ''], defalue: ''"
rgb_range: "pix value range, defalue: 255"
rgb_mean: "rgb mean, defalue: [0.4488, 0.4371, 0.4040]"
rgb_std: "rgb standard deviation, defalue: [1.0, 1.0, 1.0]"
n_colors: "the number of RGB image channels, defalue: 3"
n_feats: "the number of output features for each Conv2d, defalue: 256"
kernel_size: "kernel size for Conv2d, defalue: 3"
n_resblocks: "the number of resblocks, defalue: 32"
res_scale: "zoom scale of res branch, defalue: 0.1"

View File

@ -1,4 +1,6 @@
# Contents
[View English](./README.md)
<!-- TOC -->
@ -6,22 +8,27 @@
- [EDSR Description](#edsr-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Export](#export)
        - [Export Script](#export-script)
    - [Inference Process](#inference-process)
        - [Inference](#inference)
            - [Inference on Ascend 310 with the DIV2K dataset](#inference-on-ascend-310-with-the-div2k-dataset)
            - [Inference on Ascend 310 with other datasets](#inference-on-ascend-310-with-other-datasets)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
            - [EDSR with 2x/3x/4x super-resolution trained on DIV2K](#edsr-with-2x3x4x-super-resolution-trained-on-div2k)
        - [Evaluation Performance](#evaluation-performance)
            - [EDSR with 2x/3x/4x super-resolution evaluated on DIV2K](#edsr-with-2x3x4x-super-resolution-evaluated-on-div2k)
        - [Inference Performance](#inference-performance)
            - [EDSR with 2x/3x/4x super-resolution inference on DIV2K](#edsr-with-2x3x4x-super-resolution-inference-on-div2k)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
@ -29,113 +36,144 @@
# EDSR Description

Enhanced Deep Super-Resolution (EDSR) is a single-image super-resolution network proposed in 2017 that won first place in the NTIRE 2017 Super-Resolution Challenge. It removes the unnecessary BatchNorm modules of conventional residual networks, enlarges the model, and applies training-stabilization techniques, which significantly improves performance.

[Paper](https://arxiv.org/pdf/1707.02921.pdf): Lim B, Son S, Kim H, et al. Enhanced Deep Residual Networks for Single Image Super-Resolution[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW). IEEE, 2017.
# Model Architecture

EDSR is a stack of optimized residual blocks: compared with the original residual block, EDSR removes the BatchNorm layers and the final ReLU. Removing BatchNorm cuts memory use by roughly 40% and speeds up computation, which in turn allows a deeper and wider network. The backbone stacks 32 residual blocks, each convolution layer outputs 256 feature maps, the residual scaling factor is 0.1, and the loss function is L1.
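A minimal sketch of one such residual block (Conv -> ReLU -> Conv, no BatchNorm, branch scaled by res_scale=0.1), written against the public MindSpore `nn` API for illustration; the class name and layout here are illustrative, not the repo's actual `src/edsr.py`:

```python
import mindspore.nn as nn

class ResBlock(nn.Cell):
    """One EDSR-style residual block: no BatchNorm, no final ReLU."""
    def __init__(self, n_feats=256, kernel_size=3, res_scale=0.1):
        super().__init__()
        # 'same' padding keeps the spatial size unchanged
        self.conv1 = nn.Conv2d(n_feats, n_feats, kernel_size, pad_mode="same")
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(n_feats, n_feats, kernel_size, pad_mode="same")
        self.res_scale = res_scale

    def construct(self, x):
        res = self.conv2(self.relu(self.conv1(x))) * self.res_scale
        return x + res
```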
# Dataset

Dataset used: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)

- Dataset size: 7.11 GB, 1000 sets of valid color images (HR, LRx2, LRx3, LRx4)
    - Training set: 6.01 GB, 800 image sets
    - Validation set: 783.68 MB, 100 image sets
    - Test set: 349.53 MB, 100 image sets (no HR images)
- Data format: PNG files
    - Note: data is processed in src/dataset.py.
```shell
DIV2K
├── DIV2K_test_LR_bicubic
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├── ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├── ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├── ...
│       └── 1000x4.png
├── DIV2K_test_LR_unknown
│   ├── X2
│   │   ├── 0901x2.png
│   │   ├── ...
│   │   └── 1000x2.png
│   ├── X3
│   │   ├── 0901x3.png
│   │   ├── ...
│   │   └── 1000x3.png
│   └── X4
│       ├── 0901x4.png
│       ├── ...
│       └── 1000x4.png
├── DIV2K_train_HR
│   ├── 0001.png
│   ├── ...
│   └── 0900.png
├── DIV2K_train_LR_bicubic
│   ├── X2
│   │   ├── 0001x2.png
│   │   ├── ...
│   │   └── 0900x2.png
│   ├── X3
│   │   ├── 0001x3.png
│   │   ├── ...
│   │   └── 0900x3.png
│   └── X4
│       ├── 0001x4.png
│       ├── ...
│       └── 0900x4.png
└── DIV2K_train_LR_unknown
    ├── X2
    │   ├── 0001x2.png
    │   ├── ...
    │   └── 0900x2.png
    ├── X3
    │   ├── 0001x3.png
    │   ├── ...
    │   └── 0900x3.png
    └── X4
        ├── 0001x4.png
        ├── ...
        └── 0900x4.png
```
# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick Start

After installing MindSpore from the official website, you can follow the steps below for training and evaluation. For distributed training, you need to create an hccl configuration file in JSON format in advance; please follow the instructions at:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

- Single-device training on Ascend 910 with DIV2K

```shell
# training example (EDSR(x2) in the paper)
python train.py --batch_size 16 --scale 2 --config_path DIV2K_config.yaml > train.log 2>&1 &
# training example (EDSR(x3) in the paper - from EDSR(x2))
python train.py --batch_size 16 --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path] > train.log 2>&1 &
# training example (EDSR(x4) in the paper - from EDSR(x2))
python train.py --batch_size 16 --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path] > train.log 2>&1 &
```
- 8-device distributed training on Ascend 910 with DIV2K

```shell
# distributed training example (EDSR(x2) in the paper)
bash scripts/run_train.sh rank_table.json --scale 2 --config_path DIV2K_config.yaml
# distributed training example (EDSR(x3) in the paper)
bash scripts/run_train.sh rank_table.json --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
# distributed training example (EDSR(x4) in the paper)
bash scripts/run_train.sh rank_table.json --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
```

- Single-device evaluation on Ascend 910 with DIV2K

```shell
# evaluation example (EDSR(x2) in the paper)
python eval.py --scale 2 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path] > train.log 2>&1 &
# evaluation example (EDSR(x3) in the paper)
python eval.py --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x3 model path] > train.log 2>&1 &
# evaluation example (EDSR(x4) in the paper)
python eval.py --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x4 model path] > train.log 2>&1 &
```

- 8-device distributed evaluation on Ascend 910 with DIV2K

```shell
# distributed evaluation example (EDSR(x2) in the paper)
bash scripts/run_eval.sh rank_table.json --scale 2 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
# distributed evaluation example (EDSR(x3) in the paper)
bash scripts/run_eval.sh rank_table.json --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x3 model path]
# distributed evaluation example (EDSR(x4) in the paper)
bash scripts/run_eval.sh rank_table.json --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x4 model path]
```

- Single-device evaluation on Ascend 910 with the benchmark datasets

```shell
# evaluation example (EDSR(x2) in the paper)
python eval.py --scale 2 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x2 model path] > train.log 2>&1 &
# evaluation example (EDSR(x3) in the paper)
python eval.py --scale 3 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x3 model path] > train.log 2>&1 &
# evaluation example (EDSR(x4) in the paper)
python eval.py --scale 4 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x4 model path] > train.log 2>&1 &
```

- 8-device distributed evaluation on Ascend 910 with the benchmark datasets

```shell
# distributed evaluation example (EDSR(x2) in the paper)
bash scripts/run_eval.sh rank_table.json --scale 2 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
# distributed evaluation example (EDSR(x3) in the paper)
bash scripts/run_eval.sh rank_table.json --scale 3 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x3 model path]
# distributed evaluation example (EDSR(x4) in the paper)
bash scripts/run_eval.sh rank_table.json --scale 4 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x4 model path]
```

- Single-device evaluation on Ascend 310 with DIV2K

```shell
# inference command
bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
# inference example (EDSR(x2) in the paper)
bash scripts/run_infer_310.sh ./mindir/EDSR_x2_DIV2K-6000_50_InputSize1020.mindir ./DIV2K 2 ./infer_x2.log 0
# inference example (EDSR(x3) in the paper)
bash scripts/run_infer_310.sh ./mindir/EDSR_x3_DIV2K-6000_50_InputSize680.mindir ./DIV2K 3 ./infer_x3.log 0
# inference example (EDSR(x4) in the paper)
bash scripts/run_infer_310.sh ./mindir/EDSR_x4_DIV2K-6000_50_InputSize510.mindir ./DIV2K 4 ./infer_x4.log 0
```
- Training on ModelArts with the DIV2K dataset

If you want to run on ModelArts, refer to the documentation at [modelarts](https://support.huaweicloud.com/modelarts/)

```text
# (1) Upload the code to an S3 bucket
#     code directory: /s3_path_to_code/EDSR/
#     boot file: /s3_path_to_code/EDSR/train.py
# (2) Set the parameters on the web page; every parameter in DIV2K_config.yaml can be configured there
#     scale = 2
#     config_path = /local_path_to_code/DIV2K_config.yaml
#     enable_modelarts = True
#     pre_trained = [s3 address of the model] or [leave unset]
#     [other parameters] = [values]
# (3) Upload the DIV2K dataset to the S3 bucket and set the "training dataset" path; if it is not unzipped, also set in (2)
#     need_unzip_in_modelarts = True
# (4) Set the "training output path", "job log path", etc. on the web page
# (5) Choose an 8-device or single-device machine and create the training job
```
# Script Description
@ -144,150 +182,218 @@ sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
```bash
├── model_zoo
    ├── README.md                  // description of all models
    ├── EDSR
        ├── README_CN.md           // EDSR description
        ├── model_utils            // cloud (ModelArts) helper scripts
        ├── DIV2K_config.yaml      // EDSR parameters
        ├── scripts
        │   ├── run_train.sh       // shell script for distributed training on Ascend
        │   ├── run_eval.sh        // shell script for evaluation on Ascend
        │   ├── run_infer_310.sh   // shell script for Ascend-310 inference
        ├── src
        │   ├── dataset.py         // dataset creation
        │   ├── edsr.py            // EDSR network architecture
        │   ├── config.py          // parameter parsing
        │   ├── metric.py          // evaluation metrics
        │   ├── utils.py           // code shared by train.py and eval.py
        ├── train.py               // training script
        ├── eval.py                // evaluation script
        ├── export.py              // export checkpoint file to air/mindir
        ├── preprocess.py          // data pre-processing for Ascend-310 inference
        ├── ascend310_infer
        │   ├── src                // Ascend-310 inference source code
        │   ├── inc                // Ascend-310 inference headers
        │   ├── build.sh           // shell script to build the Ascend-310 inference program
        │   ├── CMakeLists.txt     // CMakeLists for the Ascend-310 inference program
        ├── postprocess.py         // data post-processing for Ascend-310 inference
```
## Script Parameters

Both training and evaluation parameters can be configured in DIV2K_config.yaml. Parameters with the same name in benchmark_config.yaml carry the same definition.
- The following command prints the configuration description:

```shell
python train.py --config_path DIV2K_config.yaml --help
```

- You can also read the configuration description directly in DIV2K_config.yaml; it is as follows:
```yaml
enable_modelarts: "在云道运行则需要配置为True, default: False"
data_url: "云道数据路径"
train_url: "云道代码路径"
checkpoint_url: "云道保存的路径"
data_path: "运行机器的数据路径由脚本从云道数据路径下载default: /cache/data"
output_path: "运行机器的输出路径由脚本从本地上传至checkpoint_urldefault: /cache/train"
device_target: "可选['Ascend']default: Ascend"
amp_level: "可选['O0', 'O2', 'O3', 'auto']default: O3"
loss_scale: "除了O0外其他混合精度时会做loss放缩default: 1000.0"
keep_checkpoint_max: "最多保存多少个ckpt defalue: 60"
save_epoch_frq: "每隔多少epoch保存ckpt一次 defalue: 100"
ckpt_save_dir: "保存的本地相对路径根目录是output_path defalue: ./ckpt/"
epoch_size: "训练多少个epoch defalue: 6000"
eval_epoch_frq: "训练时每隔多少epoch执行一次验证defalue: 20"
self_ensemble: "验证时执行self_ensemble仅在eval.py中使用 defalue: True"
save_sr: "验证时保存sr和hr图片仅在eval.py中使用 defalue: True"
opt_type: "优化器类型,可选['Adam']defalue: Adam"
weight_decay: "优化器权重衰减参数defalue: 0.0"
learning_rate: "学习率defalue: 0.0001"
milestones: "学习率衰减的epoch节点列表defalue: [4000]"
gamma: "学习率衰减率defalue: 0.5"
dataset_name: "数据集名称defalue: DIV2K"
lr_type: "lr图的退化方式可选['bicubic', 'unknown']defalue: bicubic"
batch_size: "为了保证效果建议8卡用2单卡用16defalue: 2"
patch_size: "训练时候的裁剪HR图大小LR图会依据scale调整裁剪大小defalue: 192"
scale: "模型的超分辨重建的尺度,可选[2,3,4], defalue: 4"
dataset_sink_mode: "训练使用数据下沉模式defalue: True"
need_unzip_in_modelarts: "从s3下载数据后加压数据defalue: False"
need_unzip_files: "需要解压的数据列表, need_unzip_in_modelarts=True时才起作用"
pre_trained: "加载预训练模型x2/x3/x4倍可以相互加载可选[[s3绝对地址], [output_path下相对地址], [本地机器绝对地址], '']defalue: ''"
rgb_range: "图片像素的范围defalue: 255"
rgb_mean: "图片RGB均值defalue: [0.4488, 0.4371, 0.4040]"
rgb_std: "图片RGB方差defalue: [1.0, 1.0, 1.0]"
n_colors: "RGB图片3通道defalue: 3"
n_feats: "每个卷积层的输出特征数量defalue: 256"
kernel_size: "卷积核大小defalue: 3"
n_resblocks: "resblocks数量defalue: 32"
res_scale: "res的分支的系数defalue: 0.1"
```
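As a quick sanity check, the two YAML documents in DIV2K_config.yaml (the defaults, then the help text) can be read with PyYAML the same way `model_utils/config.py` later in this PR does. A minimal sketch, assuming PyYAML is installed and the file sits in the working directory:

```python
import yaml

# the config file holds multiple YAML documents separated by '---'
with open("DIV2K_config.yaml", "r") as fin:
    docs = list(yaml.load_all(fin, Loader=yaml.FullLoader))

defaults, helper = docs[0], docs[1]
print(defaults["scale"])  # 4
print(helper["scale"])    # "scale for super resolution reconstruction, ..."
```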
## Export

Before running inference, the model must be exported. AIR models can only be exported in an Ascend 910 environment; MindIR models can be exported in any environment. Only batch_size 1 is supported.

### Export Script
```shell
python export.py --config_path DIV2K_config.yaml --output_path [dir to save model] --scale [SCALE] --pre_trained [pre-trained EDSR_x[SCALE] model path]
```
## Inference Process

### Inference

#### Inference on Ascend 310 with the DIV2K dataset

- Inference script

```shell
bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
```

- Examples
```shell
# inference example (EDSR(x2) in the paper)
bash scripts/run_infer_310.sh ./mindir/EDSR_x2_DIV2K-6000_50_InputSize1020.mindir ./DIV2K 2 ./infer_x2.log 0
# inference example (EDSR(x3) in the paper)
bash scripts/run_infer_310.sh ./mindir/EDSR_x3_DIV2K-6000_50_InputSize680.mindir ./DIV2K 3 ./infer_x3.log 0
# inference example (EDSR(x4) in the paper)
bash scripts/run_infer_310.sh ./mindir/EDSR_x4_DIV2K-6000_50_InputSize510.mindir ./DIV2K 4 ./infer_x4.log 0
```
- Inference metrics: check infer_x2.log, infer_x3.log, and infer_x4.log respectively; you will see
```python
# EDSR(x2) in the paper
evaluation result = {'psnr': 35.068791459971266}
# EDSR(x3) in the paper
evaluation result = {'psnr': 31.386362838892456}
# EDSR(x4) in the paper
evaluation result = {'psnr': 29.38072897971985}
```
#### Inference on Ascend 310 with other datasets

- Inference flow

```bash
# (1) Prepare the dataset: pad the lr images to one fixed size. See preprocess.py.
# (2) Export the model for that fixed input size. See export.py.
# (3) Build the inference program with build.sh inside the ascend310_infer folder, producing ascend310_infer/out/main.
# (4) Configure the dataset image path, model path, output path, etc., then run main to obtain the super-resolved images.
./ascend310_infer/out/main --mindir_path=[model] --dataset_path=[read_data_path] --device_id=[device_id] --save_dir=[save_data_path]
# (5) Post-process the images: crop away the invalid padded area, then compute metrics together with the hr images. See postprocess.py.
```
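Step (1) above pads every lr image to one fixed input size, and step (5) crops the padding back off. A minimal numpy sketch of that pair of operations (the function names are hypothetical; the repo's preprocess.py and postprocess.py are the authoritative versions):

```python
import numpy as np

def pad_to(img, target_h, target_w):
    """Zero-pad an HWC image at the bottom/right up to a fixed size."""
    h, w, c = img.shape
    out = np.zeros((target_h, target_w, c), img.dtype)
    out[:h, :w, :] = img
    return out

def unpad(img, orig_h, orig_w):
    """Crop the padded area back off (mirrors unpadding in postprocess.py)."""
    return img[:orig_h, :orig_w, :]
```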
# Model Description

## Performance

### Training Performance
#### EDSR with 2x/3x/4x super-resolution trained on DIV2K

| Parameters | Ascend | Ascend | Ascend |
| --- | --- | --- | --- |
| Model version | EDSR(x2) | EDSR(x3) | EDSR(x4) |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; EulerOS 2.8 | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; EulerOS 2.8 | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; EulerOS 2.8 |
| Upload date | 2021-09-01 | 2021-09-01 | 2021-09-01 |
| MindSpore version | 1.2.0 | 1.2.0 | 1.2.0 |
| Dataset | DIV2K | DIV2K | DIV2K |
| Training parameters | epoch=6000, total batch_size=16, lr=0.0001, patch_size=192 | epoch=6000, total batch_size=16, lr=0.0001, patch_size=192 | epoch=6000, total batch_size=16, lr=0.0001, patch_size=192 |
| Optimizer | Adam | Adam | Adam |
| Loss function | L1 | L1 | L1 |
| Output | super-resolved RGB image | super-resolved RGB image | super-resolved RGB image |
| Loss | 4.06 | 4.01 | 4.50 |
| Speed | 1 device: 16.5 s/epoch; 8 devices: 2.76 s/epoch | 1 device: 21.6 s/epoch; 8 devices: 1.8 s/epoch | 1 device: 21.0 s/epoch; 8 devices: 1.8 s/epoch |
| Total time | 1 device: 1725 min; 8 devices: 310 min | 1 device: 2234 min; 8 devices: 217 min | 1 device: 2173 min; 8 devices: 210 min |
| Parameters (M) | 40.73M | 43.68M | 43.09M |
| Fine-tune checkpoint | 467.28 MB (.ckpt file) | 501.04 MB (.ckpt file) | 494.29 MB (.ckpt file) |
### Evaluation Performance
#### EDSR with 2x/3x/4x super-resolution evaluated on DIV2K

| Parameters | Ascend | Ascend | Ascend |
| --- | --- | --- | --- |
| Model version | EDSR(x2) | EDSR(x3) | EDSR(x4) |
| Resource | Ascend 910; EulerOS 2.8 | Ascend 910; EulerOS 2.8 | Ascend 910; EulerOS 2.8 |
| Upload date | 2021-09-01 | 2021-09-01 | 2021-09-01 |
| MindSpore version | 1.2.0 | 1.2.0 | 1.2.0 |
| Dataset | DIV2K, 100 images | DIV2K, 100 images | DIV2K, 100 images |
| self_ensemble | True | True | True |
| batch_size | 1 | 1 | 1 |
| Output | super-resolved RGB image | super-resolved RGB image | super-resolved RGB image |
| Set5 psnr | 38.275 dB | 34.777 dB | 32.618 dB |
| Set14 psnr | 34.059 dB | 30.684 dB | 28.928 dB |
| B100 psnr | 32.393 dB | 29.332 dB | 27.792 dB |
| Urban100 psnr | 32.970 dB | 29.019 dB | 26.849 dB |
| DIV2K psnr | 35.063 dB | 31.380 dB | 29.370 dB |
| Inference model | 467.28 MB (.ckpt file) | 501.04 MB (.ckpt file) | 494.29 MB (.ckpt file) |
### Inference Performance

#### EDSR with 2x/3x/4x super-resolution inference on DIV2K

| Parameters | Ascend | Ascend | Ascend |
| --- | --- | --- | --- |
| Model version | EDSR(x2) | EDSR(x3) | EDSR(x4) |
| Resource | Ascend 310; Ubuntu 18.04 | Ascend 310; Ubuntu 18.04 | Ascend 310; Ubuntu 18.04 |
| Upload date | 2021-09-01 | 2021-09-01 | 2021-09-01 |
| MindSpore version | 1.2.0 | 1.2.0 | 1.2.0 |
| Dataset | DIV2K, 100 images | DIV2K, 100 images | DIV2K, 100 images |
| self_ensemble | True | True | True |
| batch_size | 1 | 1 | 1 |
| Output | super-resolved RGB image | super-resolved RGB image | super-resolved RGB image |
| DIV2K psnr | 35.068 dB | 31.386 dB | 29.380 dB |
| Inference model | 156 MB (.mindir file) | 167 MB (.mindir file) | 165 MB (.mindir file) |
# Description of Random Situation

In train.py and eval.py, we set the seed with mindspore.common.set_seed(2021).
# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ ! -d out ]; then
mkdir out
fi
cd out || exit
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

View File

@ -0,0 +1,33 @@
/*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs,
const std::string &homePath);
#endif

View File

@ -0,0 +1,141 @@
/*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include <map>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::MapTargetDevice;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::CenterCrop;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_string(save_dir, "", "save dir");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
DIR *dir = OpenDir(FLAGS_save_dir);
if (dir == nullptr) {
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
ascend310->SetBufferOptimizeMode("off_optimize");
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
auto decode = Decode();
auto normalize = Normalize({0.0, 0.0, 0.0}, {1.0, 1.0, 1.0});
auto hwc2chw = HWC2CHW();
Execute transform({decode, normalize, hwc2chw});
auto all_files = GetAllFiles(FLAGS_dataset_path);
std::map<double, double> costTime_map;
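// costTime_map records one (start-ms -> end-ms) pair per image; after the
// prediction loop these pairs are averaged into the mean inference time.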
size_t size = all_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs = 0.0;
double endTimeMs = 0.0;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
auto img = MSTensor();
auto image = ReadFileToTensor(all_files[i]);
transform(image, &img);
std::vector<MSTensor> model_inputs = model.GetInputs();
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
img.Data().get(), img.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(all_files[i], outputs, FLAGS_save_dir);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: " << average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = FLAGS_save_dir + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,144 @@
/*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> dirs;
std::vector<std::string> files;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == "..") {
continue;
} else if (filename->d_type == DT_DIR) {
dirs.emplace_back(std::string(dirName) + "/" + filename->d_name);
} else if (filename->d_type == DT_REG) {
files.emplace_back(std::string(dirName) + "/" + filename->d_name);
} else {
continue;
}
}
for (auto d : dirs) {
dir = OpenDir(d);
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
files.emplace_back(std::string(d) + "/" + filename->d_name);
}
}
std::sort(files.begin(), files.end());
for (auto &f : files) {
std::cout << "image file: " << f << std::endl;
}
return files;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs, const std::string &homePath) {
for (size_t i = 0; i < outputs.size(); ++i) {
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
size_t outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(
file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -0,0 +1,69 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
device_target: "Ascend"
# ==============================================================================
ckpt_save_dir: "./ckpt/"
self_ensemble: True
save_sr: True
# dataset options
dataset_name: "benchmark"
scale: 4
need_unzip_in_modelarts: False
# net options
pre_trained: ""
rgb_range: 255
rgb_mean: [0.4488, 0.4371, 0.4040]
rgb_std: [1.0, 1.0, 1.0]
n_colors: 3
n_feats: 256
kernel_size: 3
n_resblocks: 32
res_scale: 0.1
---
# helper
enable_modelarts: "set True if run in modelarts, default: False"
# Url for modelarts
data_url: "modelarts data path"
train_url: "modelarts code path"
checkpoint_url: "modelarts checkpoint save path"
# Path for local
data_path: "local data path, data will be download from 'data_url', default: /cache/data"
output_path: "local output path, checkpoint will be upload to 'checkpoint_url', default: /cache/train"
device_target: "choice from ['Ascend'], default: Ascend"
# ==============================================================================
# train options
ckpt_save_dir: "the relative path to save checkpoint, root path is 'output_path', defalue: ./ckpt/"
self_ensemble: "set True if wanna do self-ensemble while evaluating, defalue: True"
save_sr: "set True if wanna save sr and hr image while evaluating, defalue: True"
# dataset options
dataset_name: "dataset name, defalue: DIV2K"
scale: "scale for super resolution reconstruction, choice from [2,3,4], defalue: 4"
need_unzip_in_modelarts: "set True if wanna unzip data after download data from s3, defalue: False"
need_unzip_files: "list of zip files to unzip, only work while 'need_unzip_in_modelarts'=True"
# net options
pre_trained: "load pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3_ABS_PATH], [RELATIVE_PATH below 'output_path'], [LOCAL_ABS_PATH], ''], defalue: ''"
rgb_range: "pix value range, defalue: 255"
rgb_mean: "rgb mean, defalue: [0.4488, 0.4371, 0.4040]"
rgb_std: "rgb standard deviation, defalue: [1.0, 1.0, 1.0]"
n_colors: "the number of RGB image channels, defalue: 3"
n_feats: "the number of output features for each Conv2d, defalue: 256"
kernel_size: "kernel size for Conv2d, defalue: 3"
n_resblocks: "the number of resblocks, defalue: 32"
res_scale: "zoom scale of res branch, defalue: 0.1"

View File

@ -12,64 +12,189 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr eval script"""
"""
#################evaluate EDSR example on DIV2K########################
"""
import os
import math
import mindspore.dataset as ds
import numpy as np
from mindspore.common import set_seed
from mindspore import Tensor, ops
from src.metric import SelfEnsembleWrapperNumpy, PSNR, SaveSrHr
from src.utils import init_env, init_dataset, init_net, modelarts_pre_process, do_eval
from src.dataset import get_rank_info, LrHrImages, hwc2chw, uint8_to_float32
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_rank_id
set_seed(2021)
class HrCutter:
"""
cut hr into correct shape, for evaluating benchmark
"""
def __init__(self, lr_scale):
self.lr_scale = lr_scale
def __call__(self, lr, hr):
lrh, lrw, _ = lr.shape
hrh, hrw, _ = hr.shape
h, w = lrh * self.lr_scale, lrw * self.lr_scale
if hrh != h or hrw != w:
hr = hr[0:h, 0:w, :]
return lr, hr
class RepeatDataSet:
"""
Repeat DataSet so that it can dist evaluate Set5
"""
def __init__(self, dataset, repeat):
self.dataset = dataset
self.repeat = repeat
def __getitem__(self, idx):
return self.dataset[idx % len(self.dataset)]
def __len__(self):
return len(self.dataset) * self.repeat
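# Illustrative note: Set5 has only 5 images, so on an 8-device run
# RepeatDataSet(dataset, repeat=8 // 5 + 1) yields 10 items, and index 7
# maps back to dataset[7 % 5]; every shard of the DistributedSampler
# below then receives at least one image.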
def create_dataset_benchmark(dataset_path, scale):
"""
create a train or eval benchmark dataset
Args:
dataset_path(string): the path of dataset.
scale(int): lr scale, read data ordered by it, choices=(2,3,4)
Returns:
multi_datasets
"""
lr_scale = scale
multi_datasets = {}
for dataset_name in ["Set5", "Set14", "B100", "Urban100"]:
# get HR_PATH/*.png
dir_hr = os.path.join(dataset_path, dataset_name, "HR")
hr_pattern = os.path.join(dir_hr, "*.png")
# get LR
column_names = [f"lrx{lr_scale}", "hr"]
dir_lr = os.path.join(dataset_path, dataset_name, "LR_bicubic", f"X{lr_scale}")
lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
lrs_pattern = [lr_pattern]
device_num, rank_id = get_rank_info()
# make dataset
dataset = LrHrImages(lr_pattern=lrs_pattern, hr_pattern=hr_pattern)
if len(dataset) < device_num:
dataset = RepeatDataSet(dataset, repeat=device_num // len(dataset) + 1)
# make mindspore dataset
if device_num == 1 or device_num is None:
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
num_parallel_workers=3,
shuffle=False)
else:
sampler = ds.DistributedSampler(num_shards=device_num, shard_id=rank_id, shuffle=False, offset=0)
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
num_parallel_workers=3,
sampler=sampler)
# define map operations
transform_img = [
HrCutter(lr_scale),
hwc2chw,
uint8_to_float32,
]
# pre-process hr lr
generator_dataset = generator_dataset.map(input_columns=column_names,
output_columns=column_names,
column_order=column_names,
operations=transform_img)
# apply batch operations
generator_dataset = generator_dataset.batch(1, drop_remainder=False)
multi_datasets[dataset_name] = generator_dataset
return multi_datasets
class BenchmarkPSNR(PSNR):
"""
eval psnr for Benchmark
"""
def __init__(self, rgb_range, shave, channels_scale):
super(BenchmarkPSNR, self).__init__(rgb_range=rgb_range, shave=shave)
self.channels_scale = channels_scale
self.c_scale = Tensor(np.array(self.channels_scale, dtype=np.float32).reshape((1, -1, 1, 1)))
self.sum = ops.ReduceSum(keep_dims=True)
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
sr, hr = inputs
sr = self.quantize(sr)
diff = (sr - hr) / self.rgb_range
diff = diff * self.c_scale
valid = self.sum(diff, 1)
if self.shave is not None and self.shave != 0:
valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)]
mse_list = (valid ** 2).mean(axis=(1, 2, 3))
mse_list = self._convert_data(mse_list).tolist()
psnr_list = [float(1e32) if mse == 0 else (-10.0 * math.log10(mse)) for mse in mse_list]
self._accumulate(psnr_list)
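# The update above follows the standard EDSR benchmark protocol: the RGB
# error is projected onto the luma (Y) channel via the BT.601 gray
# coefficients (see gray_coeffs below), border pixels are shaved off, and
# psnr = -10 * log10(mse), with 1e32 standing in for a zero-error infinity.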
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
"""
run eval
"""
print(config)
cfg = config
init_env(cfg)
net = init_net(cfg)
eval_net = SelfEnsembleWrapperNumpy(net) if cfg.self_ensemble else net
if cfg.dataset_name == "DIV2K":
cfg.batch_size = 1
cfg.patch_size = -1
ds_val = init_dataset(cfg, "valid")
metrics = {
"psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + cfg.scale),
}
if config.save_sr:
save_img_dir = os.path.join(cfg.output_path, "HrSr")
os.makedirs(save_img_dir, exist_ok=True)
metrics["num_sr"] = SaveSrHr(save_img_dir)
do_eval(eval_net, ds_val, metrics)
print("eval success", flush=True)
elif cfg.dataset_name == "benchmark":
multi_datasets = create_dataset_benchmark(dataset_path=cfg.data_path, scale=cfg.scale)
result = {}
for dname, ds_val in multi_datasets.items():
dpnsr = f"{dname}_psnr"
gray_coeffs = [65.738, 129.057, 25.064]
channels_scale = [x / 256.0 for x in gray_coeffs]
metrics = {
dpnsr: BenchmarkPSNR(rgb_range=cfg.rgb_range, shave=cfg.scale, channels_scale=channels_scale)
}
if config.save_sr:
save_img_dir = os.path.join(cfg.output_path, "HrSr", dname)
os.makedirs(save_img_dir, exist_ok=True)
metrics["num_sr"] = SaveSrHr(save_img_dir)
result[dpnsr] = do_eval(eval_net, ds_val, metrics)[dpnsr]
if get_rank_id() == 0:
print(result, flush=True)
print("eval success", flush=True)
else:
raise RuntimeError("Unsupported dataset.")
if __name__ == '__main__':
print("Start eval function!")
eval_net()
run_eval()

View File

@ -12,50 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export net together with checkpoint into air/mindir/onnx models"""
"""
##############export checkpoint file into air, mindir models#################
python export.py
"""
import os
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
import src.model as edsr
import mindspore as ms
from mindspore import Tensor, export, context
from src.utils import init_net
from model_utils.config import config
from model_utils.device_adapter import get_device_id
from model_utils.moxing_adapter import moxing_wrapper
if __name__ == "__main__":
run_export(args1)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
MAX_HR_SIZE = 2040
@moxing_wrapper()
def run_export():
"""
run export
"""
print(config)
cfg = config
if cfg.pre_trained is None:
raise RuntimeError('config.pre_trained is None.')
net = init_net(cfg)
max_lr_size = MAX_HR_SIZE // cfg.scale
input_arr = Tensor(np.ones([1, cfg.n_colors, max_lr_size, max_lr_size]), ms.float32)
file_name = os.path.splitext(os.path.basename(cfg.pre_trained))[0]
file_name = file_name + f"_InputSize{max_lr_size}"
file_path = os.path.join(cfg.output_path, file_name)
file_format = 'MINDIR'
num_params = sum([param.size for param in net.parameters_dict().values()])
export(net, input_arr, file_name=file_path, file_format=file_format)
print(f"export success", flush=True)
print(f"{cfg.pre_trained} -> {file_path}.{file_format.lower()}, net parameters = {num_params/1000000:>0.4}M",
flush=True)
if __name__ == '__main__':
run_export()

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.edsr import EDSR
def edsr(*args, **kwargs):
return EDSR(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == "edsr":
return edsr(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,128 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
except yaml.YAMLError as err:
raise ValueError("Failed to parse yaml") from err
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
final_config = Config(final_config)
return final_config
config = get_config()
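# Usage note: other scripts in this PR (train.py, eval.py, export.py) simply
# do `from model_utils.config import config` and read fields such as
# config.scale or config.pre_trained; command-line flags like `--scale 2`
# override the YAML defaults merged above.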

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
return job_id if job_id else "default"
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# Run the main function
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
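
moxing_wrapper is meant to be stacked on a script's entry point; a minimal sketch of that pattern follows (train_net and prepare_data are illustrative names, not part of this commit):

from model_utils.moxing_adapter import moxing_wrapper

def prepare_data():
    # pre_process hook: on ModelArts it runs after data_url has been
    # synced to config.data_path, just before the wrapped function
    print("dataset ready")

@moxing_wrapper(pre_process=prepare_data)
def train_net():
    # main logic; on ModelArts, outputs under config.output_path are
    # copied back to train_url after this function returns
    pass

if __name__ == "__main__":
    train_net()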

View File

@ -0,0 +1,191 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''post process for 310 inference'''
import os
import math
from PIL import Image
import numpy as np
from mindspore import Tensor
from src.utils import init_env, modelarts_pre_process
from src.dataset import FolderImagePair, AUG_DICT
from src.metric import PSNR
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
def read_bin(bin_path):
img = np.fromfile(bin_path, dtype=np.float32)
num_pix = img.size
img_shape = int(math.sqrt(num_pix // 3))
if 1 * 3 * img_shape * img_shape != num_pix:
raise RuntimeError(f'bin file error, it is not an output of the EDSR network: {bin_path}')
img = img.reshape(1, 3, img_shape, img_shape)
return img
def read_bin_as_hwc(bin_path):
nchw_img = read_bin(bin_path)
chw_img = np.squeeze(nchw_img)
hwc_img = chw_img.transpose(1, 2, 0)
return hwc_img
def unpadding(img, target_shape):
h, w = target_shape[0], target_shape[1]
img_h, img_w, _ = img.shape
if img_h > h:
img = img[:h, :, :]
if img_w > w:
img = img[:, :w, :]
return img
def img_to_tensor(img):
img = np.array([img.transpose(2, 0, 1)], np.float32)
img = Tensor(img)
return img
def float_to_uint8(img):
clip_img = np.clip(img, 0, 255)
round_img = np.round(clip_img)
uint8_img = round_img.astype(np.uint8)
return uint8_img
def bin_to_png(cfg):
"""
convert the bin files output by ascend310_infer to png images
"""
dataset_path = cfg.data_path
dataset_type = "valid"
aug_keys = list(AUG_DICT.keys())
lr_scale = cfg.scale
if cfg.self_ensemble:
dir_sr_bin = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_bin", f"X{lr_scale}")
save_sr_se_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_self_ensemble", f"X{lr_scale}")
if os.path.isdir(dir_sr_bin):
os.makedirs(save_sr_se_dir, exist_ok=True)
bin_patterns = [os.path.join(dir_sr_bin, f"*x{lr_scale}_{a_key}_0.bin") for a_key in aug_keys]
dataset = FolderImagePair(bin_patterns, reader=read_bin_as_hwc)
for i in range(len(dataset)):
img_key = dataset.get_key(i)
sr_se_path = os.path.join(save_sr_se_dir, f"{img_key}x{lr_scale}.png")
if os.path.isfile(sr_se_path):
continue
data = dataset[i]
img_key, sr_8 = data[0], data[1:]
sr = np.zeros_like(sr_8[0], dtype=np.float64)
for img, a_key in zip(sr_8, aug_keys):
aug = AUG_DICT[a_key]
for a in reversed(aug):
img = a(img)
sr += img
sr /= len(sr_8)
sr = float_to_uint8(sr)
Image.fromarray(sr).save(sr_se_path)
print(f"merge sr bin save to {sr_se_path}")
return
if not cfg.self_ensemble:
dir_sr_bin = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_bin", f"X{lr_scale}")
save_sr_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR", f"X{lr_scale}")
if os.path.isdir(dir_sr_bin):
os.makedirs(save_sr_dir, exist_ok=True)
bin_patterns = [os.path.join(dir_sr_bin, f"*x{lr_scale}_0_0.bin")]
dataset = FolderImagePair(bin_patterns, reader=read_bin_as_hwc)
for i in range(len(dataset)):
img_key = dataset.get_key(i)
sr_path = os.path.join(save_sr_dir, f"{img_key}x{lr_scale}.png")
if os.path.isfile(sr_path):
continue
img_key, sr = dataset[i]
sr = float_to_uint8(sr)
Image.fromarray(sr).save(sr_path)
print(f"merge sr bin save to {sr_path}")
return
def get_hr_sr_dataset(cfg):
"""
make hr sr dataset
"""
dataset_path = cfg.data_path
dataset_type = "valid"
lr_scale = cfg.scale
dir_patterns = []
# get HR_PATH/*.png
dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
hr_pattern = os.path.join(dir_hr, "*.png")
dir_patterns.append(hr_pattern)
# get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
se = "_self_ensemble" if cfg.self_ensemble else ""
dir_sr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR" + se, f"X{lr_scale}")
if not os.path.isdir(dir_sr):
raise RuntimeError(f'{dir_sr} is not a dir for saving sr')
sr_pattern = os.path.join(dir_sr, f"*x{lr_scale}.png")
dir_patterns.append(sr_pattern)
# make dataset
dataset = FolderImagePair(dir_patterns)
return dataset
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_post_process():
"""
run post process
"""
print(config)
cfg = config
lr_scale = cfg.scale
init_env(cfg)
print("begin to run bin_to_png...")
bin_to_png(cfg)
print("bin_to_png finish")
dataset = get_hr_sr_dataset(cfg)
metrics = {
"psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + lr_scale),
}
total_step = len(dataset)
setw = len(str(total_step))
for i in range(len(dataset)):
_, hr, sr = dataset[i]
sr = unpadding(sr, hr.shape)
sr = img_to_tensor(sr)
hr = img_to_tensor(hr)
_ = [m.update(sr, hr) for m in metrics.values()]
result = {k: m.eval(sync=False) for k, m in metrics.items()}
print(f"[{i+1:>{setw}}/{total_step:>{setw}}] result = {result}", flush=True)
result = {k: m.eval(sync=False) for k, m in metrics.items()}
print(f"evaluation result = {result}", flush=True)
print("post_process success", flush=True)
if __name__ == "__main__":
run_post_process()
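
read_bin recovers the square spatial size from the element count alone (num_pix == 1 * 3 * H * H). A quick round-trip check, runnable inside this module so read_bin is in scope, with a temporary file standing in for a real 310 inference output:

import numpy as np

h = 510  # MAX_HR_SIZE // scale for the x4 model, matching preprocess.py
fake = np.random.rand(1, 3, h, h).astype(np.float32)
fake.tofile("/tmp/fake_sr.bin")  # hypothetical path
restored = read_bin("/tmp/fake_sr.bin")
assert restored.shape == (1, 3, h, h)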

View File

@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''pre process for 310 inference'''
import os
from PIL import Image
import numpy as np
from src.utils import modelarts_pre_process
from src.dataset import FolderImagePair, AUG_DICT
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
MAX_HR_SIZE = 2040
def padding(img, target_shape):
h, w = target_shape[0], target_shape[1]
img_h, img_w, _ = img.shape
dh, dw = h - img_h, w - img_w
if dh < 0 or dw < 0:
raise RuntimeError(f"target_shape is bigger than img.shape, {target_shape} > {img.shape}")
if dh != 0 or dw != 0:
img = np.pad(img, ((0, dh), (0, dw), (0, 0)), "constant")
return img
def get_lr_dataset(cfg):
"""
get lr dataset
"""
dataset_path = cfg.data_path
lr_scale = cfg.scale
lr_type = cfg.lr_type
dataset_type = "valid"
self_ensemble = "_self_ensemble" if cfg.self_ensemble else ""
# get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
lrs_pattern = []
dir_lr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}", f"X{lr_scale}")
lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
lrs_pattern.append(lr_pattern)
save_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}_AUG{self_ensemble}", f"X{lr_scale}")
os.makedirs(save_dir, exist_ok=True)
save_format = os.path.join(save_dir, "{}" + f"x{lr_scale}" + "_{}.png")
# make dataset
dataset = FolderImagePair(lrs_pattern)
return dataset, save_format
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_pre_process():
"""
run pre process
"""
print(config)
cfg = config
aug_dict = AUG_DICT
if not cfg.self_ensemble:
aug_dict = {"0": AUG_DICT["0"]}
dataset, save_format = get_lr_dataset(cfg)
for i in range(len(dataset)):
img_key = dataset.get_key(i)
org_img = None
for a_key, aug in aug_dict.items():
save_path = save_format.format(img_key, a_key)
if os.path.isfile(save_path):
continue
if org_img is None:
_, lr = dataset[i]
target_shape = [MAX_HR_SIZE // cfg.scale, MAX_HR_SIZE // cfg.scale]
org_img = padding(lr, target_shape)
img = org_img.copy()
for a in aug:
img = a(img)
Image.fromarray(img).save(save_path)
print(f"[{i+1}/{len(dataset)}]\tsave {save_path}\tshape = {img.shape}", flush=True)
print("pre_process success", flush=True)
if __name__ == "__main__":
run_pre_process()
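
padding only ever grows an image up to target_shape and raises if the image is already larger. A small sketch with synthetic data, using MAX_HR_SIZE and padding from this module:

import numpy as np

lr = np.zeros((500, 480, 3), dtype=np.uint8)   # stand-in for an x4 LR image
target = [MAX_HR_SIZE // 4, MAX_HR_SIZE // 4]  # 510 x 510
padded = padding(lr, target)
assert padded.shape == (510, 510, 3)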

View File

@ -1,72 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH1 ]; then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py \
--batch_size 2 \
--lr 1e-4 \
--scale 2 \
--task_id 0 \
--dir_data $PATH2 \
--epochs 1000 \
--test_every 8000 \
--n_resblocks 32 \
--n_feats 256 \
--res_scale 0.1 \
--patch_size 48 > train.log 2>&1 &
cd ..
done

View File

@ -1,59 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]; then
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ -d "train" ]; then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
env >env.log
python train.py \
--batch_size 16 \
--lr 1e-4 \
--scale 2 \
--task_id 0 \
--dir_data $PATH1 \
--epochs 1000 \
--test_every 1000 \
--n_resblocks 32 \
--n_feats 256 \
--res_scale 0.1 \
--patch_size 48 > train.log 2>&1 &

View File

@ -1,65 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
DATASET_TYPE=$3
if [ ! -d $PATH1 ]; then
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]; then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation ..."
python eval.py \
--dir_data=${PATH1} \
--test_only \
--ext img \
--ckpt_path=${PATH2} \
--task_id 0 \
--scale 2 \
--data_test=${DATASET_TYPE} \
--device_id 0 \
--n_resblocks 32 \
--n_feats 256 \
--res_scale 0.1 > log.txt 2>&1 &

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 1 ]
then
echo "Usage: sh scripts/run_eval.sh [RANK_TABLE_FILE] --opt1 opt1_value --opt2 opt2_value ..."
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
PATH1=$(realpath $1)
export RANK_TABLE_FILE=$PATH1
echo "RANK_TABLE_FILE=${PATH1}"
export PYTHONPATH=$PWD:$PYTHONPATH
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./eval_parallel$i
mkdir ./eval_parallel$i
cp -r ./src ./eval_parallel$i
cp -r ./model_utils ./eval_parallel$i
cp -r ./*.yaml ./eval_parallel$i
cp ./eval.py ./eval_parallel$i
echo "start evaluation for rank $RANK_ID, device $DEVICE_ID"
cd ./eval_parallel$i ||exit
env > env.log
export args=${*:2}
python eval.py $args > eval.log 2>&1 &
cd ..
done
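
For reference, a typical 8-device evaluation launch with this script would look like bash scripts/run_eval.sh ./rank_table_8pcs.json --config_path DIV2K_config.yaml --data_path /path/to/DIV2K (the rank table path and option values are illustrative); everything after the rank table is forwarded verbatim to eval.py.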

View File

@ -0,0 +1,151 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 5 ]]; then
echo "Usage: bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, default: 0"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
scale=$3
if [[ $scale -ne "2" && $scale -ne "3" && $scale -ne "4" ]]; then
echo "[SCALE] should be in [2,3,4]"
exit 1
fi
log_file="./run_infer.log"
if [ $# -ge 4 ]; then
log_file=$4
fi
log_file=$(get_real_path $log_file)
device_id=0
if [ $# == 5 ]; then
device_id=$5
fi
self_ensemble="True"
echo "***************** param *****************"
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "scale: "$scale
echo "log file: "$log_file
echo "device id: "$device_id
echo "self_ensemble: "$self_ensemble
echo "***************** param *****************"
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
export PYTHONPATH=$PWD:$PYTHONPATH
function compile_app()
{
echo "begin to compile app..."
cd ./ascend310_infer || exit
bash build.sh >> $log_file 2>&1
cd -
echo "finshi compile app"
}
function preprocess()
{
echo "begin to preprocess..."
export DEVICE_ID=$device_id
export RANK_SIZE=1
python preprocess.py --data_path=$data_path --config_path=DIV2K_config.yaml --device_target=CPU --scale=$scale --self_ensemble=$self_ensemble >> $log_file 2>&1
echo "finshi preprocess"
}
function infer()
{
echo "begin to infer..."
if [ $self_ensemble == "True" ]; then
read_data_path=$data_path"/DIV2K_valid_LR_bicubic_AUG_self_ensemble/X"$scale
else
read_data_path=$data_path"/DIV2K_valid_LR_bicubic_AUG/X"$scale
fi
save_data_path=$data_path"/DIV2K_valid_SR_bin/X"$scale
if [ -d $save_data_path ]; then
rm -rf $save_data_path
fi
mkdir -p $save_data_path
./ascend310_infer/out/main --mindir_path=$model --dataset_path=$read_data_path --device_id=$device_id --save_dir=$save_data_path >> $log_file 2>&1
echo "finshi infer"
}
function postprocess()
{
echo "begin to postprocess..."
export DEVICE_ID=$device_id
export RANK_SIZE=1
python postprocess.py --data_path=$data_path --config_path=DIV2K_config.yaml --device_target=CPU --scale=$scale --self_ensemble=$self_ensemble >> $log_file 2>&1
echo "finshi postprocess"
}
echo "" > $log_file
echo "read the log command: "
echo " tail -f $log_file"
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed, check $log_file"
exit 1
fi
preprocess
if [ $? -ne 0 ]; then
echo "preprocess code failed, check $log_file"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed, check $log_file"
exit 1
fi
postprocess
if [ $? -ne 0 ]; then
echo "postprocess failed, check $log_file"
exit 1
fi
cat $log_file | tail -n 3 | head -n 1
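
As a concrete example, assuming an exported MINDIR at ./EDSR_x4.mindir and the DIV2K root at ./DIV2K, the full pipeline can be launched with bash scripts/run_infer_310.sh ./EDSR_x4.mindir ./DIV2K 4 ./run_infer.log 0 and monitored with tail -f ./run_infer.log, as the script itself prints.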

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 1 ]
then
echo "Usage: sh scripts/run_train.sh [RANK_TABLE_FILE] --opt1 opt1_value --opt2 opt2_value ..."
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
PATH1=$(realpath $1)
export RANK_TABLE_FILE=$PATH1
echo "RANK_TABLE_FILE=${PATH1}"
export PYTHONPATH=$PWD:$PYTHONPATH
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp -r ./model_utils ./train_parallel$i
cp -r ./*.yaml ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
export args=${*:2}
python train.py $args > train.log 2>&1 &
cd ..
done
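
Likewise, a typical 8-device training launch (paths illustrative) is bash scripts/run_train.sh ./rank_table_8pcs.json --config_path DIV2K_config.yaml --data_path /path/to/DIV2K; as with the evaluation script, all options after the rank table are passed straight through to train.py.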

View File

@ -1,84 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""args parser"""
import argparse
parser = argparse.ArgumentParser(description='EDSR')
# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
help='dataset directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-900',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
help='input patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='EDSR',
help='model name')
parser.add_argument('--n_resblocks', type=int, default=32,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=256,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=0.1,
help='residual scaling')
# Training specifications
parser.add_argument('--test_every', type=int, default=8000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=2,
help='input batch size for training')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--loss_scale', type=float, default=1024.0,
help='init loss scale')
# ckpt specifications
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=5,
help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
help='path of saved ckpt')
# alltask
parser.add_argument('--task_id', type=int, default=0)
args, unparsed = parser.parse_known_args()
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
args.epochs = 1e4
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False

View File

@ -1,96 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import math
import numpy as np
import mindspore
import mindspore.nn as nn
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
pad_mode='pad',
padding=(kernel_size//2), has_bias=bias)
class MeanShift(mindspore.nn.Conv2d):
"""MeanShift"""
def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, dtype=mindspore.float32):
std = mindspore.Tensor(rgb_std, dtype)
weight = mindspore.Tensor(np.eye(3), dtype).reshape(
3, 3, 1, 1) / std.reshape(3, 1, 1, 1)
bias = sign * rgb_range * mindspore.Tensor(rgb_mean, dtype) / std
super(MeanShift, self).__init__(3, 3, kernel_size=1,
has_bias=True, weight_init=weight, bias_init=bias)
for p in self.get_parameters():
p.requires_grad = False
class ResBlock(nn.Cell):
"""ResBlock"""
def __init__(
self, conv, n_feats, kernel_size,
bias=True, act=nn.ReLU(), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if i == 0:
m.append(act)
self.body = nn.SequentialCell(m)
self.res_scale = res_scale
self.mul = mindspore.ops.Mul()
def construct(self, x):
res = self.body(x)
res = self.mul(res, self.res_scale)
res += x
return res
class PixelShuffle(nn.Cell):
"""PixelShuffle"""
def __init__(self, upscale_factor):
super(PixelShuffle, self).__init__()
self.DepthToSpace = mindspore.ops.DepthToSpace(upscale_factor)
def construct(self, x):
return self.DepthToSpace(x)
def Upsampler(conv, scale, n_feats, bias=True):
"""Upsampler"""
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(PixelShuffle(2))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(PixelShuffle(3))
else:
raise NotImplementedError
return m

View File

@ -1,75 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common"""
import os
import random
import numpy as np
def get_patch(*args, patch_size=96, scale=2, input_large=False):
"""common"""
ih, iw = args[0].shape[:2]
tp = patch_size
ip = tp // scale
ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)
if not input_large:
tx, ty = scale * ix, scale * iy
else:
tx, ty = ix, iy
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
return ret
def set_channel(*args, n_channels=3):
"""common"""
def _set_channel(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
c = img.shape[2]
if n_channels == 3 and c == 1:
img = np.concatenate([img] * n_channels, 2)
return img[:, :, :n_channels]
return [_set_channel(a) for a in args]
def np2Tensor(*args, rgb_range=255):
def _np2Tensor(img):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
input_data = np_transpose.astype(np.float32)
output = input_data * (rgb_range / 255)
return output
return [_np2Tensor(a) for a in args]
def augment(*args, hflip=True, rot=True):
"""common"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
"""common"""
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(a) for a in args]
def search(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
return item_list

View File

@ -1,43 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""div2k"""
import os
from src.data.srdata import SRData
class DIV2K(SRData):
"""DIV2K"""
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
data_range = [r.split('-') for r in args.data_range.split('/')]
if train:
data_range = data_range[0]
else:
if args.test_only and len(data_range) == 1:
data_range = data_range[0]
else:
data_range = data_range[1]
self.begin, self.end = list(map(int, data_range))
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
self.dir_hr = None
self.dir_lr = None
def _scan(self):
names_hr, names_lr = super(DIV2K, self)._scan()
names_hr = names_hr[self.begin - 1:self.end]
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
return names_hr, names_lr
def _set_filesystem(self, dir_data):
super(DIV2K, self)._set_filesystem(dir_data)
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')

View File

@ -1,212 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""""srdata"""
import os
import glob
import random
import pickle
import imageio
from src.data import common
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class SRData:
"""srdata"""
def __init__(self, args, name='', train=True, benchmark=False):
self.args = args
self.name = name
self.train = train
self.split = 'train' if train else 'test'
self.do_eval = True
self.benchmark = benchmark
self.input_large = (args.model == 'VDSR')
self.scale = args.scale
self.idx_scale = 0
self._set_filesystem(args.dir_data)
self._set_img(args)
if train:
self._repeat(args)
def _set_img(self, args):
"""srdata"""
if args.ext.find('img') < 0:
path_bin = os.path.join(self.apath, 'bin')
os.makedirs(path_bin, exist_ok=True)
list_hr, list_lr = self._scan()
if args.ext.find('img') >= 0 or self.benchmark:
self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0:
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
for s in self.scale:
if s == 1:
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
else:
os.makedirs(
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
for h in list_hr:
b = h.replace(self.apath, path_bin)
b = b.replace(self.ext[0], '.pt')
self.images_hr.append(b)
self._check_and_load(args.ext, h, b, verbose=True)
for i, ll in enumerate(list_lr):
for l in ll:
b = l.replace(self.apath, path_bin)
b = b.replace(self.ext[1], '.pt')
self.images_lr[i].append(b)
self._check_and_load(args.ext, l, b, verbose=True)
def _repeat(self, args):
"""srdata"""
n_patches = args.batch_size * args.test_every
n_images = len(args.data_train) * len(self.images_hr)
if n_images == 0:
self.repeat = 0
else:
self.repeat = max(n_patches // n_images, 1)
def _scan(self):
"""srdata"""
names_hr = sorted(
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
names_lr = [[] for _ in self.scale]
for f in names_hr:
filename, _ = os.path.splitext(os.path.basename(f))
for si, s in enumerate(self.scale):
if s != 1:
scale = s
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
.format(s, filename, scale, self.ext[1])))
for si, s in enumerate(self.scale):
if s == 1:
names_lr[si] = names_hr
return names_hr, names_lr
def _set_filesystem(self, dir_data):
self.apath = os.path.join(dir_data, self.name[0])
self.dir_hr = os.path.join(self.apath, 'HR')
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
self.ext = ('.png', '.png')
def _check_and_load(self, ext, img, f, verbose=True):
if not os.path.isfile(f) or ext.find('reset') >= 0:
if verbose:
print('Making a binary: {}'.format(f))
with open(f, 'wb') as _f:
pickle.dump(imageio.imread(img), _f)
def __getitem__(self, idx):
lr, hr, _ = self._load_file(idx)
pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
def __len__(self):
if self.train:
return len(self.images_hr) * self.repeat
return len(self.images_hr)
def _get_index(self, idx):
if self.train:
return idx % len(self.images_hr)
return idx
def _load_file_deblur(self, idx, train=True):
"""srdata"""
idx = self._get_index(idx)
if train:
f_hr = self.images_hr[idx]
f_lr = self.images_lr[idx]
else:
f_hr = self.deblur_hr_test[idx]
f_lr = self.deblur_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
filename = f_hr[-27:-17] + filename
hr = imageio.imread(f_hr)
lr = imageio.imread(f_lr)
return lr, hr, filename
def _load_file_hr(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_hr = self.images_hr[idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
return hr, filename
def _load_rain_test(self, idx):
f_hr = self.derain_hr_test[idx]
f_lr = self.derain_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_lr))
norain = imageio.imread(f_hr)
rain = imageio.imread(f_lr)
return norain, rain, filename
def _load_file(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_hr = self.images_hr[idx]
f_lr = self.images_lr[self.idx_scale][idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
lr = imageio.imread(f_lr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
with open(f_lr, 'rb') as _f:
lr = pickle.load(_f)
return lr, hr, filename
def get_patch_hr(self, hr):
"""srdata"""
if self.train:
hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
return hr
def get_patch_img_hr(self, img, patch_size=96, scale=2):
"""srdata"""
ih, iw = img.shape[:2]
tp = patch_size
ip = tp // scale
ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)
ret = img[iy:iy + ip, ix:ix + ip, :]
return ret
def get_patch(self, lr, hr):
"""srdata"""
scale = self.scale[self.idx_scale]
if self.train:
lr, hr = common.get_patch(
lr, hr,
patch_size=self.args.patch_size * scale,
scale=scale)
if not self.args.no_augment:
lr, hr = common.augment(lr, hr)
else:
ih, iw = lr.shape[:2]
hr = hr[0:ih * scale, 0:iw * scale]
return lr, hr
def set_scale(self, idx_scale):
if not self.input_large:
self.idx_scale = idx_scale
else:
self.idx_scale = random.randint(0, len(self.scale) - 1)

View File

@ -0,0 +1,321 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import glob
import re
from functools import reduce
import random
from PIL import Image
import numpy as np
import mindspore.dataset as ds
def get_rank_info():
"""
get rank size and rank id
"""
from model_utils.moxing_adapter import get_rank_id, get_device_num
return get_device_num(), get_rank_id()
class FolderImagePair:
"""
get image pair
dir_patterns(list): a list of image path patterns, such as ["/LR/*.jpg", "/HR/*.png", ...];
the file key is the sequence of chars matched by * and ?
reader(object/func): a method to read an image from a path.
"""
def __init__(self, dir_patterns, reader=None):
self.dir_patterns = dir_patterns
self.reader = reader
self.pair_keys, self.image_pairs = self.scan_pair(self.dir_patterns)
@classmethod
def scan_pair(cls, dir_patterns):
"""
scan pair
"""
images = []
for _dir in dir_patterns:
imgs = glob.glob(_dir)
_dir = os.path.basename(_dir)
pat = _dir.replace("*", "(.*)").replace("?", "(.?)")
pat = re.compile(pat, re.I | re.M)
keys = [re.findall(pat, os.path.basename(p))[0] for p in imgs]
images.append({k: v for k, v in zip(keys, imgs)})
same_keys = reduce(lambda x, y: set(x) & set(y), images)
same_keys = sorted(same_keys)
image_pairs = [[d[k] for d in images] for k in same_keys]
same_keys = [x if isinstance(x, str) else "_".join(x) for x in same_keys]
return same_keys, image_pairs
def get_key(self, idx):
return self.pair_keys[idx]
def __getitem__(self, idx):
if self.reader is None:
images = [Image.open(p) for p in self.image_pairs[idx]]
images = [img.convert('RGB') for img in images]
images = [np.array(img) for img in images]
else:
images = [self.reader(p) for p in self.image_pairs[idx]]
pair_key = self.pair_keys[idx]
return (pair_key, *images)
def __len__(self):
return len(self.pair_keys)
class LrHrImages(FolderImagePair):
"""
make LrHrImages dataset
"""
def __init__(self, lr_pattern, hr_pattern, reader=None):
self.hr_pattern = hr_pattern
self.lr_pattern = lr_pattern
self.dir_patterns = []
if isinstance(self.lr_pattern, str):
self.is_multi_lr = False
self.dir_patterns.append(self.lr_pattern)
elif len(lr_pattern) == 1:
self.is_multi_lr = False
self.dir_patterns.append(self.lr_pattern[0])
else:
self.is_multi_lr = True
self.dir_patterns.extend(self.lr_pattern)
self.dir_patterns.append(self.hr_pattern)
super(LrHrImages, self).__init__(self.dir_patterns, reader=reader)
def __getitem__(self, idx):
_, *images = super(LrHrImages, self).__getitem__(idx)
return images
class _BasePatchCutter:
"""
cut patch from images
patch_size(int): patch size, input images should be bigger than patch_size.
lr_scale(int/list): lr scales for input images. Choice from [1,2,3,4, or their combination]
"""
def __init__(self, patch_size, lr_scale):
self.patch_size = patch_size
self.multi_lr_scale = lr_scale
if isinstance(lr_scale, int):
self.multi_lr_scale = [lr_scale]
else:
self.multi_lr_scale = [*lr_scale]
self.max_lr_scale_idx = self.multi_lr_scale.index(max(self.multi_lr_scale))
self.max_lr_scale = self.multi_lr_scale[self.max_lr_scale_idx]
def get_tx_ty(self, target_height, target_width, target_patch_size):
raise NotImplementedError()
def __call__(self, *images):
target_img = images[self.max_lr_scale_idx]
tp = self.patch_size // self.max_lr_scale
th, tw, _ = target_img.shape
tx, ty = self.get_tx_ty(th, tw, tp)
patch_images = []
for _, (img, lr_scale) in enumerate(zip(images, self.multi_lr_scale)):
x = tx * self.max_lr_scale // lr_scale
y = ty * self.max_lr_scale // lr_scale
p = tp * self.max_lr_scale // lr_scale
patch_images.append(img[y:(y + p), x:(x + p), :])
return tuple(patch_images)
class RandomPatchCutter(_BasePatchCutter):
def __init__(self, patch_size, lr_scale):
super(RandomPatchCutter, self).__init__(patch_size=patch_size, lr_scale=lr_scale)
def get_tx_ty(self, target_height, target_width, target_patch_size):
target_x = random.randrange(0, target_width - target_patch_size + 1)
target_y = random.randrange(0, target_height - target_patch_size + 1)
return target_x, target_y
class CentrePatchCutter(_BasePatchCutter):
def __init__(self, patch_size, lr_scale):
super(CentrePatchCutter, self).__init__(patch_size=patch_size, lr_scale=lr_scale)
def get_tx_ty(self, target_height, target_width, target_patch_size):
target_x = (target_width - target_patch_size) // 2
target_y = (target_height - target_patch_size) // 2
return target_x, target_y
def hflip(img):
return img[:, ::-1, :]
def vflip(img):
return img[::-1, :, :]
def trnsp(img):
return img.transpose(1, 0, 2)
AUG_LIST = [
[],
[trnsp],
[vflip],
[vflip, trnsp],
[hflip],
[hflip, trnsp],
[hflip, vflip],
[hflip, vflip, trnsp],
]
AUG_DICT = {
"0": [],
"t": [trnsp],
"v": [vflip],
"vt": [vflip, trnsp],
"h": [hflip],
"ht": [hflip, trnsp],
"hv": [hflip, vflip],
"hvt": [hflip, vflip, trnsp],
}
def flip_and_rotate(*images):
aug = random.choice(AUG_LIST)
res = []
for img in images:
for a in aug:
img = a(img)
res.append(img)
return tuple(res)
def hwc2chw(*images):
res = [i.transpose(2, 0, 1) for i in images]
return tuple(res)
def uint8_to_float32(*images):
res = [(i.astype(np.float32) if i.dtype == np.uint8 else i) for i in images]
return tuple(res)
def create_dataset_DIV2K(config, dataset_type="train", num_parallel_workers=10, shuffle=True):
"""
create a train or eval DIV2K dataset
Args:
config(dict):
dataset_path(string): the path of dataset.
scale(int/list): lr scale, read data ordered by it, choices=(2,3,4,[2],[3],[4],[2,3],[2,4],[3,4],[2,3,4])
lr_type(string): lr images type, choices=("bicubic", "unknown"), Default "bicubic"
batch_size(int): the batch size of dataset. (train param), Default 1
patch_size(int): train data size. (train param), Default -1
epoch_size(int): times to repeat dataset for dataset_sink_mode, Default None
dataset_type(string): choices=("train", "valid", "test"), Default "train"
num_parallel_workers(int): num-workers to read data, Default 10
shuffle(bool): shuffle dataset. Default: True
Returns:
dataset
"""
dataset_path = config["dataset_path"]
lr_scale = config["scale"]
lr_type = config.get("lr_type", "bicubic")
batch_size = config.get("batch_size", 1)
patch_size = config.get("patch_size", -1)
epoch_size = config.get("epoch_size", None)
# for multi lr scale, such as [2,3,4]
if isinstance(lr_scale, int):
multi_lr_scale = [lr_scale]
else:
multi_lr_scale = lr_scale
# get HR_PATH/*.png
dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
hr_pattern = os.path.join(dir_hr, "*.png")
# get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
column_names = []
lrs_pattern = []
for lr_scale in multi_lr_scale:
dir_lr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}", f"X{lr_scale}")
lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
lrs_pattern.append(lr_pattern)
column_names.append(f"lrx{lr_scale}")
column_names.append("hr") # ["lrx2","lrx3","lrx4",..., "hr"]
# make dataset
dataset = LrHrImages(lr_pattern=lrs_pattern, hr_pattern=hr_pattern)
# make mindspore dataset
device_num, rank_id = get_rank_info()
if device_num == 1 or device_num is None:
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
num_parallel_workers=num_parallel_workers,
shuffle=shuffle and dataset_type == "train")
elif dataset_type == "train":
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
num_parallel_workers=num_parallel_workers,
shuffle=shuffle and dataset_type == "train",
num_shards=device_num, shard_id=rank_id)
else:
sampler = ds.DistributedSampler(num_shards=device_num, shard_id=rank_id, shuffle=False, offset=0)
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
num_parallel_workers=num_parallel_workers,
sampler=sampler)
# define map operations
if dataset_type == "train":
transform_img = [
RandomPatchCutter(patch_size, multi_lr_scale + [1]),
flip_and_rotate,
hwc2chw,
uint8_to_float32,
]
elif patch_size > 0:
transform_img = [
CentrePatchCutter(patch_size, multi_lr_scale + [1]),
hwc2chw,
uint8_to_float32,
]
else:
transform_img = [
hwc2chw,
uint8_to_float32,
]
# pre-process hr lr
generator_dataset = generator_dataset.map(input_columns=column_names,
output_columns=column_names,
column_order=column_names,
operations=transform_img)
# apply batch operations
generator_dataset = generator_dataset.batch(batch_size, drop_remainder=False)
# apply repeat operations
if dataset_type == "train" and epoch_size is not None and epoch_size != 1:
generator_dataset = generator_dataset.repeat(epoch_size)
return generator_dataset
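
create_dataset_DIV2K takes a plain dict rather than the parsed config object; a minimal sketch of building a training pipeline (the dataset_path value is an assumption about the local layout, the other values mirror DIV2K_config.yaml):

train_cfg = {
    "dataset_path": "/cache/data/DIV2K",  # assumed DIV2K root
    "scale": 4,
    "lr_type": "bicubic",
    "batch_size": 2,
    "patch_size": 192,
}
train_ds = create_dataset_DIV2K(train_cfg, dataset_type="train")
for lr, hr in train_ds.create_tuple_iterator():
    # for scale 4 / patch 192: lr is (2, 3, 48, 48), hr is (2, 3, 192, 192)
    break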

View File

@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EDSR"""
import numpy as np
from mindspore import Parameter
from mindspore import nn, ops
from mindspore.common.initializer import TruncatedNormal
class RgbNormal(nn.Cell):
"""
"MeanShift" in EDSR paper pytorch-code:
https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/common.py
it is not unreasonable in the case below
if std != 1 and sign = -1: y = x * rgb_std - rgb_range * rgb_mean
if std != 1 and sign = 1: y = x * rgb_std + rgb_range * rgb_mean
they are not inverse operation for each other!
so use "RgbNormal" instead, it runs as below:
if inverse = False: y = (x / rgb_range - mean) / std
if inverse = True : x = (y * std + mean) * rgb_range
"""
def __init__(self, rgb_range, rgb_mean, rgb_std, inverse=False):
super(RgbNormal, self).__init__()
self.rgb_range = rgb_range
self.rgb_mean = rgb_mean
self.rgb_std = rgb_std
self.inverse = inverse
std = np.array(self.rgb_std, dtype=np.float32)
mean = np.array(self.rgb_mean, dtype=np.float32)
if not inverse:
# y: (x / rgb_range - mean) / std <=> x * (1.0 / rgb_range / std) + (-mean) / std
weight = (1.0 / self.rgb_range / std).reshape((1, -1, 1, 1))
bias = (-mean / std).reshape((1, -1, 1, 1))
else:
# x: (y * std + mean) * rgb_range <=> y * (std * rgb_range) + mean * rgb_range
weight = (self.rgb_range * std).reshape((1, -1, 1, 1))
bias = (mean * rgb_range).reshape((1, -1, 1, 1))
self.weight = Parameter(weight, requires_grad=False)
self.bias = Parameter(bias, requires_grad=False)
def construct(self, x):
return x * self.weight + self.bias
def extend_repr(self):
s = 'rgb_range={}, rgb_mean={}, rgb_std={}, inverse = {}' \
.format(
self.rgb_range,
self.rgb_mean,
self.rgb_std,
self.inverse,
)
return s
def make_conv2d(in_channels, out_channels, kernel_size, has_bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
pad_mode="same", has_bias=has_bias, weight_init=TruncatedNormal(0.02))
class ResBlock(nn.Cell):
"""
Resnet Block
"""
def __init__(
self, in_channels, out_channels, kernel_size=1, has_bias=True, res_scale=1):
super(ResBlock, self).__init__()
self.conv1 = make_conv2d(in_channels, in_channels, kernel_size, has_bias)
self.relu = nn.ReLU()
self.conv2 = make_conv2d(in_channels, out_channels, kernel_size, has_bias)
self.res_scale = res_scale
def construct(self, x):
res = self.conv1(x)
res = self.relu(res)
res = self.conv2(res)
res = res * self.res_scale
x = x + res
return x
class PixelShuffle(nn.Cell):
"""
PixelShuffle using ops.DepthToSpace
"""
def __init__(self, upscale_factor):
super(PixelShuffle, self).__init__()
self.upscale_factor = upscale_factor
self.upper = ops.DepthToSpace(self.upscale_factor)
def construct(self, x):
return self.upper(x)
def extend_repr(self):
return 'upscale_factor={}'.format(self.upscale_factor)
def UpsamplerBlockList(upscale_factor, n_feats, has_bias=True):
"""
make Upsampler Block List
"""
if upscale_factor == 1:
return []
allow_sub_upscale_factor = [2, 3, None]
for sub in allow_sub_upscale_factor:
if sub is None:
raise NotImplementedError(
f"Only scales whose factors are all in {allow_sub_upscale_factor[:-1]} are supported")
if upscale_factor % sub == 0:
break
sub_block_list = [
make_conv2d(n_feats, sub*sub*n_feats, 3, has_bias),
PixelShuffle(sub),
]
return sub_block_list + UpsamplerBlockList(upscale_factor // sub, n_feats, has_bias)
class Upsampler(nn.Cell):
def __init__(self, scale, n_feats, has_bias=True):
super(Upsampler, self).__init__()
up = UpsamplerBlockList(scale, n_feats, has_bias)
self.up = nn.SequentialCell(*up)
def construct(self, x):
x = self.up(x)
return x
class EDSR(nn.Cell):
"""
EDSR network
"""
def __init__(self, scale, n_feats, kernel_size, n_resblocks,
n_colors=3,
res_scale=0.1,
rgb_range=255,
rgb_mean=(0.0, 0.0, 0.0),
rgb_std=(1.0, 1.0, 1.0)):
super(EDSR, self).__init__()
self.norm = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=False)
self.de_norm = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=True)
m_head = [make_conv2d(n_colors, n_feats, kernel_size)]
m_body = [
ResBlock(n_feats, n_feats, kernel_size, res_scale=res_scale)
for _ in range(n_resblocks)
]
m_body.append(make_conv2d(n_feats, n_feats, kernel_size))
m_tail = [
Upsampler(scale, n_feats),
make_conv2d(n_feats, n_colors, kernel_size)
]
self.head = nn.SequentialCell(m_head)
self.body = nn.SequentialCell(m_body)
self.tail = nn.SequentialCell(m_tail)
def construct(self, x):
x = self.norm(x)
x = self.head(x)
x = x + self.body(x)
x = self.tail(x)
x = self.de_norm(x)
return x
def load_pre_trained_param_dict(self, new_param_dict, strict=True):
"""
load pre_trained param dict from edsr_x2
"""
own_param = self.parameters_dict()
for name, new_param in new_param_dict.items():
if len(name) >= 4 and name[:4] == "net.":
name = name[4:]
if name in own_param:
if isinstance(new_param, Parameter):
param = own_param[name]
if tuple(param.data.shape) == tuple(new_param.data.shape):
param.set_data(type(param.data)(new_param.data))
elif name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_param[name].shape, new_param.shape))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in parameters_dict()'
.format(name))
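
Tying the pieces together, a minimal forward-pass sketch using the hyper-parameters from DIV2K_config.yaml (the random input merely stands in for a real LR batch):

import numpy as np
from mindspore import Tensor

net = EDSR(scale=4, n_feats=256, kernel_size=3, n_resblocks=32,
           res_scale=0.1, rgb_range=255,
           rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0))
lr = Tensor(np.random.rand(1, 3, 48, 48).astype(np.float32) * 255)
sr = net(lr)  # shape (1, 3, 192, 192): the x4 super-resolved output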

View File

@ -0,0 +1,330 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for evaluation."""
import os
import math
from PIL import Image
import numpy as np
from mindspore import nn, Tensor, ops
from mindspore import dtype as mstype
from mindspore.ops.operations.comm_ops import ReduceOp
try:
from model_utils.device_adapter import get_rank_id, get_device_num
except ImportError:
get_rank_id = None
get_device_num = None
finally:
pass
class SelfEnsembleWrapperNumpy:
"""
SelfEnsembleWrapperNumpy using numpy
"""
def __init__(self, net):
super(SelfEnsembleWrapperNumpy, self).__init__()
self.net = net
def hflip(self, x):
return x[:, :, :, ::-1]
def vflip(self, x):
return x[:, :, ::-1, :]
def trnsps(self, x):
return x.transpose(0, 1, 3, 2)
def aug_x8(self, x):
"""
do x8 augments for input image
"""
# hflip
hx = self.hflip(x)
# vflip
vx = self.vflip(x)
vhx = self.vflip(hx)
# trnsps
tx = self.trnsps(x)
thx = self.trnsps(hx)
tvx = self.trnsps(vx)
tvhx = self.trnsps(vhx)
return x, hx, vx, vhx, tx, thx, tvx, tvhx
def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
"""
undo x8 augments for input images
"""
# trnsps
tvhx = self.trnsps(tvhx)
tvx = self.trnsps(tvx)
thx = self.trnsps(thx)
tx = self.trnsps(tx)
# vflip
tvhx = self.vflip(tvhx)
tvx = self.vflip(tvx)
vhx = self.vflip(vhx)
vx = self.vflip(vx)
# hflip
tvhx = self.hflip(tvhx)
thx = self.hflip(thx)
vhx = self.hflip(vhx)
hx = self.hflip(hx)
return x, hx, vx, vhx, tx, thx, tvx, tvhx
def to_numpy(self, *inputs):
if not inputs:
return None
if len(inputs) == 1:
return inputs[0].asnumpy()
return [x.asnumpy() for x in inputs]
def to_tensor(self, *inputs):
if not inputs:
return None
if len(inputs) == 1:
return Tensor(inputs[0])
return [Tensor(x) for x in inputs]
def set_train(self, mode=True):
self.net.set_train(mode)
return self
def __call__(self, x):
x = self.to_numpy(x)
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
x0 = self.net(x0)
x1 = self.net(x1)
x2 = self.net(x2)
x3 = self.net(x3)
x4 = self.net(x4)
x5 = self.net(x5)
x6 = self.net(x6)
x7 = self.net(x7)
x0, x1, x2, x3, x4, x5, x6, x7 = self.to_numpy(x0, x1, x2, x3, x4, x5, x6, x7)
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
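# Illustrative sketch (not used by the pipeline): a round-trip check that
# aug_x8_reverse undoes aug_x8 on an NCHW numpy array. The wrapped net is
# never called here, so None is passed as a stand-in.
def _example_aug_x8_roundtrip():
    x = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)
    wrapper = SelfEnsembleWrapperNumpy(None)
    augs = wrapper.aug_x8(x)
    outs = wrapper.aug_x8_reverse(*augs)
    for out in outs:
        # every de-augmented tensor must equal the original input
        assert np.allclose(out, x)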
class SelfEnsembleWrapper(nn.Cell):
"""
because of [::-1] operator error, use "SelfEnsembleWrapperNumpy" instead
"""
def __init__(self, net):
super(SelfEnsembleWrapper, self).__init__()
self.net = net
def hflip(self, x):
raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
def vflip(self, x):
raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
def trnsps(self, x):
return x.transpose(0, 1, 3, 2)
def aug_x8(self, x):
"""
do x8 augments for input image
"""
# hflip
hx = self.hflip(x)
# vflip
vx = self.vflip(x)
vhx = self.vflip(hx)
# trnsps
tx = self.trnsps(x)
thx = self.trnsps(hx)
tvx = self.trnsps(vx)
tvhx = self.trnsps(vhx)
return x, hx, vx, vhx, tx, thx, tvx, tvhx
def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
"""
undo x8 augments for input images
"""
# trnsps
tvhx = self.trnsps(tvhx)
tvx = self.trnsps(tvx)
thx = self.trnsps(thx)
tx = self.trnsps(tx)
# vflip
tvhx = self.vflip(tvhx)
tvx = self.vflip(tvx)
vhx = self.vflip(vhx)
vx = self.vflip(vx)
# hflip
tvhx = self.hflip(tvhx)
thx = self.hflip(thx)
vhx = self.hflip(vhx)
hx = self.hflip(hx)
return x, hx, vx, vhx, tx, thx, tvx, tvhx
def construct(self, x):
"""
do x8 aug, run network, undo x8 aug, calculate mean for 8 output
"""
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
x0 = self.net(x0)
x1 = self.net(x1)
x2 = self.net(x2)
x3 = self.net(x3)
x4 = self.net(x4)
x5 = self.net(x5)
x6 = self.net(x6)
x7 = self.net(x7)
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
class Quantizer(nn.Cell):
"""
clip by [0.0, 255.0], rount to int
"""
def __init__(self, _min=0.0, _max=255.0):
super(Quantizer, self).__init__()
self.round = ops.Round()
self._min = _min
self._max = _max
def construct(self, x):
x = ops.clip_by_value(x, self._min, self._max)
x = self.round(x)
return x
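# Illustrative usage (assumed values, not part of the pipeline): Quantizer maps
# raw network output back to valid 8-bit pixel values before PSNR is computed.
#   q = Quantizer(0.0, 255.0)
#   q(Tensor([-3.2, 128.6, 300.0], mstype.float32))  # -> [0., 129., 255.]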
class TensorSyncer(nn.Cell):
"""
sync metric values from all mindspore-processes
"""
def __init__(self, _type="sum"):
super(TensorSyncer, self).__init__()
self._type = _type.lower()
if self._type == "sum":
self.ops = ops.AllReduce(ReduceOp.SUM)
elif self._type == "gather":
self.ops = ops.AllGather()
else:
raise ValueError(f"TensorSyncer._type == {self._type} is not support")
def construct(self, x):
return self.ops(x)
class _DistMetric(nn.Metric):
"""
gather data from all rank while eval(True)
_type(str): choice from ["avg", "sum"].
"""
def __init__(self, _type):
super(_DistMetric, self).__init__()
self._type = _type.lower()
self.all_reduce_sum = None
if get_device_num is not None and get_device_num() > 1:
self.all_reduce_sum = TensorSyncer(_type="sum")
self.clear()
def _accumulate(self, value):
if isinstance(value, (list, tuple)):
self._acc_value += sum(value)
self._count += len(value)
else:
self._acc_value += value
self._count += 1
def clear(self):
self._acc_value = 0.0
self._count = 0
def eval(self, sync=True):
"""
sync: True, return metric value merged from all mindspore-processes
sync: False, return metric value in this single mindspore-processes
"""
if self._count == 0:
raise RuntimeError('self._count == 0')
if self.sum is not None and sync:
data = Tensor([self._acc_value, self._count], mstype.float32)
data = self.all_reduce_sum(data)
acc_value, count = self._convert_data(data).tolist()
else:
acc_value, count = self._acc_value, self._count
if self._type == "avg":
return acc_value / count
if self._type == "sum":
return acc_value
raise RuntimeError(f"_DistMetric._type={self._type} is not support")
class PSNR(_DistMetric):
"""
Define PSNR metric for SR network.
"""
def __init__(self, rgb_range, shave):
super(PSNR, self).__init__(_type="avg")
self.shave = shave
self.rgb_range = rgb_range
self.quantize = Quantizer(0.0, 255.0)
def update(self, *inputs):
"""
update psnr
"""
if len(inputs) != 2:
raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
sr, hr = inputs
sr = self.quantize(sr)
diff = (sr - hr) / self.rgb_range
valid = diff
if self.shave is not None and self.shave != 0:
valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)]
mse_list = (valid ** 2).mean(axis=(1, 2, 3))
mse_list = self._convert_data(mse_list).tolist()
        psnr_list = [float(1e32) if mse == 0 else (-10.0 * math.log10(mse)) for mse in mse_list]
self._accumulate(psnr_list)
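# Worked example: a uniform error of one gray level on rgb_range=255 inputs
# gives mse = (1/255)**2 per pixel, so the PSNR computed above is
# -10 * log10((1/255)**2) = 20 * log10(255) ~= 48.13 dB.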
class SaveSrHr(_DistMetric):
"""
help to save sr and hr
"""
def __init__(self, save_dir):
super(SaveSrHr, self).__init__(_type="sum")
self.save_dir = save_dir
self.quantize = Quantizer(0.0, 255.0)
self.rank_id = 0 if get_rank_id is None else get_rank_id()
self.device_num = 1 if get_device_num is None else get_device_num()
def update(self, *inputs):
"""
update images to save
"""
if len(inputs) != 2:
raise ValueError('SaveSrHr need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
sr, hr = inputs
sr = self.quantize(sr)
sr = self._convert_data(sr).astype(np.uint8)
hr = self._convert_data(hr).astype(np.uint8)
for s, h in zip(sr.transpose(0, 2, 3, 1), hr.transpose(0, 2, 3, 1)):
idx = self._count * self.device_num + self.rank_id
sr_path = os.path.join(self.save_dir, f"{idx:0>4}_sr.png")
Image.fromarray(s).save(sr_path)
hr_path = os.path.join(self.save_dir, f"{idx:0>4}_hr.png")
Image.fromarray(h).save(hr_path)
self._accumulate(1)

View File

@ -1,102 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""
import math
import numpy as np
import cv2
def quantize(img, rgb_range):
"""quantize image range to 0-255"""
pixel_range = 255 / rgb_range
img = np.multiply(img, pixel_range)
img = np.clip(img, 0, 255)
img = np.round(img) / pixel_range
return img
def calc_psnr(sr, hr, scale, rgb_range):
"""calculate psnr"""
hr = np.float32(hr)
sr = np.float32(sr)
diff = (sr - hr) / rgb_range
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
diff = np.multiply(diff, gray_coeffs).sum(1)
if hr.size == 1:
return 0
if scale != 1:
shave = scale
else:
shave = scale + 6
if scale == 1:
valid = diff
else:
valid = diff[..., shave:-shave, shave:-shave]
mse = np.mean(pow(valid, 2))
return -10 * math.log10(mse)
def rgb2ycbcr(img, y_only=True):
"""from rgb space to ycbcr space"""
    img = img.astype(np.float32)
if y_only:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
return rlt
def calc_ssim(img1, img2, scale):
"""calculate ssim value"""
def ssim(img1, img2):
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
border = 0
if scale != 1:
border = scale
else:
border = scale + 6
img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1_y = img1_y[border:h - border, border:w - border]
img2_y = img2_y[border:h - border, border:w - border]
if img1_y.ndim == 2:
return ssim(img1_y, img2_y)
if img1.ndim == 3:
if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[..., i], img2[..., i]))
            return np.array(ssims).mean()
if img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')

View File

@ -1,65 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr_model"""
import mindspore.nn as nn
from src import common
class EDSR(nn.Cell):
"""EDSR"""
def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU()
self.sub_mean = common.MeanShift(args.rgb_range)
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = []
m_tail += common.Upsampler(conv, scale, n_feats)
m_tail.append(conv(n_feats, args.n_colors, kernel_size))
self.head = nn.SequentialCell(m_head)
self.body = nn.SequentialCell(m_body)
self.tail = nn.SequentialCell(m_tail)
def construct(self, x):
"""construct"""
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x

View File

@ -12,76 +12,179 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr train wrapper"""
"""
#################utils for train.py and eval.py########################
"""
import os
import time
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint
class Trainer():
"""Trainer"""
def __init__(self, args, loader, my_model):
self.args = args
self.scale = args.scale
self.trainloader = loader
self.model = my_model
self.model.set_train()
self.criterion = nn.L1Loss()
self.loss_history = []
self.begin_time = time.time()
self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
self.loss_net = nn.WithLossCell(self.model, self.criterion)
self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer)
def train(self, epoch):
"""Trainer"""
losses = 0
batch_idx = 0
for batch_idx, imgs in enumerate(self.trainloader):
lr = imgs["LR"]
hr = imgs["HR"]
lr = Tensor(lr, mstype.float32)
hr = Tensor(hr, mstype.float32)
t1 = time.time()
loss = self.net(lr, hr)
t2 = time.time()
losses += loss.asnumpy()
print('Epoch: %g, Step: %g , loss: %f, time: %f s ' % \
(epoch, batch_idx, loss.asnumpy(), t2 - t1), end='\n', flush=True)
print("the epoch loss is", losses / (batch_idx + 1), flush=True)
self.loss_history.append(losses / (batch_idx + 1))
print(self.loss_history)
t = time.time() - self.begin_time
t = int(t)
print(", running time: %gh%g'%g''"%(t//3600, (t-t//3600*3600)//60, t%60), flush=True)
os.makedirs(self.args.save, exist_ok=True)
if self.args.rank == 0 and (epoch+1)%10 == 0:
save_checkpoint(self.net, self.args.save + "model_" + str(self.epoch) + '.ckpt')
def update_learning_rate(self, epoch):
"""Update learning rates for all the networks; called at the end of every epoch.
:param epoch: current epoch
:type epoch: int
:param lr: learning rate of cyclegan
:type lr: float
:param niter: number of epochs with the initial learning rate
:type niter: int
:param niter_decay: number of epochs to linearly decay learning rate to zero
:type niter_decay: int
"""
self.epoch = epoch
print("*********** epoch: {} **********".format(epoch))
lr = self.args.lr / (2 ** ((epoch+1)//200))
self.adjust_lr('model', self.optimizer, lr)
print("*********************************")
def adjust_lr(self, name, optimizer, lr):
"""Adjust learning rate for the corresponding model.
:param name: name of model
:type name: str
:param optimizer: the optimizer of the corresponding model
        :type optimizer: mindspore.nn.Optimizer
:param lr: learning rate to be adjusted
:type lr: float
"""
lr_param = optimizer.get_lr()
lr_param.assign_value(Tensor(lr, mstype.float32))
print('==> ' + name + ' learning rate: ', lr_param.asnumpy())
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint
from model_utils.config import config
from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
from .dataset import create_dataset_DIV2K
from .edsr import EDSR
def init_env(cfg):
"""
init env for mindspore
"""
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
device_num = get_device_num()
if cfg.device_target == "Ascend":
context.set_context(device_id=get_device_id())
if device_num > 1:
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
elif cfg.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
if device_num > 1:
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
elif cfg.device_target == "CPU":
pass
else:
raise ValueError("Unsupported platform.")
def init_dataset(cfg, dataset_type="train"):
"""
init DIV2K dataset
"""
ds_cfg = {
"dataset_path": cfg.data_path,
"scale": cfg.scale,
"lr_type": cfg.lr_type,
"batch_size": cfg.batch_size,
"patch_size": cfg.patch_size,
}
if cfg.dataset_name == "DIV2K":
dataset = create_dataset_DIV2K(config=ds_cfg,
dataset_type=dataset_type,
num_parallel_workers=10,
shuffle=dataset_type == "Train")
else:
raise ValueError("Unsupported dataset.")
return dataset
def init_net(cfg):
"""
init edsr network
"""
net = EDSR(scale=cfg.scale,
n_feats=cfg.n_feats,
kernel_size=cfg.kernel_size,
n_resblocks=cfg.n_resblocks,
n_colors=cfg.n_colors,
res_scale=cfg.res_scale,
rgb_range=cfg.rgb_range,
rgb_mean=cfg.rgb_mean,
rgb_std=cfg.rgb_std,)
if cfg.pre_trained:
pre_trained_path = os.path.join(cfg.output_path, cfg.pre_trained)
if len(cfg.pre_trained) >= 5 and cfg.pre_trained[:5] == "s3://":
pre_trained_path = cfg.pre_trained
import moxing as mox
mox.file.shift("os", "mox") # then system can read file from s3://
elif os.path.isfile(cfg.pre_trained):
pre_trained_path = cfg.pre_trained
elif os.path.isfile(pre_trained_path):
pass
else:
raise ValueError(f"pre_trained error: {cfg.pre_trained}")
print(f"loading pre_trained = {pre_trained_path}", flush=True)
param_dict = load_checkpoint(pre_trained_path)
net.load_pre_trained_param_dict(param_dict, strict=False)
return net
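# Note on the resolution order above: `pre_trained` may be an s3:// URL (read
# via moxing), an existing local file path, or a file name relative to
# `output_path`; anything else raises ValueError.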
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
zip_isexist = zipfile.is_zipfile(zip_file)
zip_name = os.path.basename(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
data_print = int(data_num / 4) if data_num > 4 else 1
len_data_num = len(str(data_num))
for i, _file in enumerate(fz.namelist()):
if i % data_print == 0:
print("[{1:>{0}}/{2:>{0}}] {3:>2}% const time: {4:0>2}:{5:0>2} unzipping {6}".format(
len_data_num,
i,
data_num,
int(i / data_num * 100),
int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60),
zip_name,
flush=True))
fz.extract(_file, save_dir)
print(" finish const time: {:0>2}:{:0>2} unzipping {}".format(
int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60),
zip_name,
flush=True))
else:
print("{} is not zip.".format(zip_name), flush=True)
if config.enable_modelarts and config.need_unzip_in_modelarts:
sync_lock = "/tmp/unzip_sync.lock"
        # Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
for ufile in config.need_unzip_files:
zip_file = os.path.join(config.data_path, ufile)
save_dir = os.path.dirname(zip_file)
unzip(zip_file, save_dir)
print("===Finish extract data synchronization===", flush=True)
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data.".format(get_device_id()), flush=True)
config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)
def do_eval(eval_network, ds_val, metrics, cur_epoch=None):
"""
do eval for psnr and save hr, sr
"""
eval_network.set_train(False)
total_step = ds_val.get_dataset_size()
setw = len(str(total_step))
begin = time.time()
step_begin = time.time()
rank_id = get_rank_id()
for i, (lr, hr) in enumerate(ds_val):
sr = eval_network(lr)
_ = [m.update(sr, hr) for m in metrics.values()]
result = {k: m.eval(sync=False) for k, m in metrics.items()}
result["time"] = time.time() - step_begin
step_begin = time.time()
print(f"[{i+1:>{setw}}/{total_step:>{setw}}] rank = {rank_id} result = {result}", flush=True)
result = {k: m.eval(sync=True) for k, m in metrics.items()}
result["time"] = time.time() - begin
if cur_epoch is not None:
result["epoch"] = cur_epoch
if rank_id == 0:
print(f"evaluation result = {result}", flush=True)
eval_network.set_train(True)
return result

View File

@ -12,62 +12,138 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""edsr train script"""
import os
from mindspore import context
from mindspore import dataset as ds
"""
#################train EDSR example on DIV2K########################
"""
import numpy as np
import mindspore.nn as nn
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
from mindspore import Tensor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
from mindspore.train.model import Model
from src.args import args
from src.data.div2k import DIV2K
from src.model import EDSR
from mindspore.common import set_seed
def train_net():
"""train edsr"""
set_seed(1)
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = int(os.getenv('RANK_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))
# if distribute:
if device_num > 1:
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=device_num, gradients_mean=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
shard_id=rank_id, shuffle=True)
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
net_m = EDSR(args)
print("Init net successfully")
if args.ckpt_path:
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net_m, param_dict)
print("Load net weight successfully")
step_size = train_de_dataset.get_dataset_size()
lr = []
for i in range(0, args.epochs):
cur_lr = args.lr / (2 ** ((i + 1)//200))
lr.extend([cur_lr] * step_size)
opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
loss = nn.L1Loss()
model = Model(net_m, loss_fn=loss, optimizer=opt)
time_cb = TimeMonitor(data_size=step_size)
from src.metric import PSNR
from src.utils import init_env, init_dataset, init_net, modelarts_pre_process, do_eval
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_rank_id, get_device_num
set_seed(2021)
def lr_steps_edsr(lr, milestones, gamma, epoch_size, steps_per_epoch, last_epoch=None):
    """Build a per-step learning rate list that decays by `gamma` at each milestone epoch."""
    lr_each_step = []
    step_begin_epoch = [0] + milestones
    step_end_epoch = milestones + [epoch_size]
for begin, end in zip(step_begin_epoch, step_end_epoch):
lr_each_step += [lr] * (end - begin) * steps_per_epoch
lr *= gamma
if last_epoch is not None:
lr_each_step = lr_each_step[last_epoch * steps_per_epoch:]
return np.array(lr_each_step).astype(np.float32)
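# Illustrative sketch (assumes a hypothetical 100 steps per epoch): with the
# default config (learning_rate=1e-4, milestones=[4000], gamma=0.5,
# epoch_size=6000):
#   lr_all = lr_steps_edsr(1e-4, [4000], 0.5, 6000, 100)
#   lr_all.shape == (600000,)        # 6000 epochs * 100 steps
#   lr_all[0] == np.float32(1e-4)    # epochs [0, 4000)
#   np.isclose(lr_all[-1], 5e-5)     # epochs [4000, 6000), after one decay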
def init_opt(cfg, net):
"""
init opt to train edsr
"""
lr = lr_steps_edsr(lr=cfg.learning_rate, milestones=cfg.milestones, gamma=cfg.gamma,
epoch_size=cfg.epoch_size, steps_per_epoch=cfg.steps_per_epoch, last_epoch=None)
loss_scale = 1.0 if cfg.amp_level == "O0" else cfg.loss_scale
if cfg.opt_type == "Adam":
opt = nn.Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()),
learning_rate=Tensor(lr),
weight_decay=cfg.weight_decay,
loss_scale=loss_scale)
elif cfg.opt_type == "SGD":
opt = nn.SGD(params=filter(lambda x: x.requires_grad, net.get_parameters()),
learning_rate=Tensor(lr),
weight_decay=cfg.weight_decay,
momentum=cfg.momentum,
dampening=cfg.dampening if hasattr(cfg, "dampening") else 0.0,
nesterov=cfg.nesterov if hasattr(cfg, "nesterov") else False,
loss_scale=loss_scale)
else:
raise ValueError("Unsupported optimizer.")
return opt
class EvalCallBack(Callback):
"""
eval callback
"""
def __init__(self, eval_network, ds_val, eval_epoch_frq, epoch_size, metrics, result_evaluation=None):
self.eval_network = eval_network
self.ds_val = ds_val
self.eval_epoch_frq = eval_epoch_frq
self.epoch_size = epoch_size
self.result_evaluation = result_evaluation
self.metrics = metrics
self.best_result = None
self.eval_network.set_train(False)
def epoch_end(self, run_context):
"""
do eval in epoch end
"""
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
if cur_epoch % self.eval_epoch_frq == 0 or cur_epoch == self.epoch_size:
result = do_eval(self.eval_network, self.ds_val, self.metrics, cur_epoch=cur_epoch)
if self.best_result is None or self.best_result["psnr"] < result["psnr"]:
self.best_result = result
if get_rank_id() == 0:
print(f"best evaluation result = {self.best_result}", flush=True)
if isinstance(self.result_evaluation, dict):
for k, v in result.items():
r_list = self.result_evaluation.get(k)
if r_list is None:
r_list = []
self.result_evaluation[k] = r_list
r_list.append(v)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
"""
run train
"""
print(config)
cfg = config
init_env(cfg)
ds_train = init_dataset(cfg, "train")
ds_val = init_dataset(cfg, "valid")
net = init_net(cfg)
cfg.steps_per_epoch = ds_train.get_dataset_size()
opt = init_opt(cfg, net)
loss = nn.L1Loss(reduction='mean')
eval_net = net
model = Model(net, loss_fn=loss, optimizer=opt, amp_level=cfg.amp_level)
metrics = {
"psnr": PSNR(rgb_range=cfg.rgb_range, shave=True),
}
eval_cb = EvalCallBack(eval_net, ds_val, cfg.eval_epoch_frq, cfg.epoch_size, metrics=metrics)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.steps_per_epoch * cfg.save_epoch_frq,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor()
ckpoint_cb = ModelCheckpoint(prefix=f"EDSR_x{cfg.scale}_" + cfg.dataset_name, directory=cfg.ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
keep_checkpoint_max=args.ckpt_save_max)
ckpt_cb = ModelCheckpoint(prefix="edsr", directory=args.ckpt_save_path, config=config_ck)
if device_id == 0:
cb += [ckpt_cb]
model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)
cbs = [time_cb, ckpoint_cb, loss_cb, eval_cb]
if get_device_num() > 1 and get_rank_id() != 0:
cbs = [time_cb, loss_cb, eval_cb]
model.train(cfg.epoch_size, ds_train, dataset_sink_mode=cfg.dataset_sink_mode, callbacks=cbs)
print("train success", flush=True)
if __name__ == "__main__":
train_net()
if __name__ == '__main__':
run_train()