commit
5a851daf2f
|
@ -0,0 +1,126 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
device_target: "Ascend"
|
||||
|
||||
# ==============================================================================
|
||||
# train options
|
||||
amp_level: "O3"
|
||||
loss_scale: 1000.0 # for ['O2', 'O3', 'auto']
|
||||
keep_checkpoint_max: 60
|
||||
save_epoch_frq: 100
|
||||
ckpt_save_dir: "./ckpt/"
|
||||
epoch_size: 6000
|
||||
|
||||
# eval options
|
||||
eval_epoch_frq: 20
|
||||
self_ensemble: True
|
||||
save_sr: True
|
||||
|
||||
# Adam opt options
|
||||
opt_type: Adam
|
||||
weight_decay: 0.0
|
||||
|
||||
# learning rate options
|
||||
learning_rate: 0.0001
|
||||
milestones: [4000]
|
||||
gamma: 0.5
|
||||
|
||||
# dataset options
|
||||
dataset_name: "DIV2K"
|
||||
lr_type: "bicubic"
|
||||
batch_size: 2
|
||||
patch_size: 192
|
||||
scale: 4
|
||||
dataset_sink_mode: True
|
||||
need_unzip_in_modelarts: False
|
||||
need_unzip_files:
|
||||
- "DIV2K_train_HR.zip"
|
||||
- "DIV2K_train_LR_bicubic_X2.zip"
|
||||
- "DIV2K_train_LR_bicubic_X3.zip"
|
||||
- "DIV2K_train_LR_bicubic_X4.zip"
|
||||
- "DIV2K_train_LR_unknown_X2.zip"
|
||||
- "DIV2K_train_LR_unknown_X3.zip"
|
||||
- "DIV2K_train_LR_unknown_X4.zip"
|
||||
- "DIV2K_valid_HR.zip"
|
||||
- "DIV2K_valid_LR_bicubic_X2.zip"
|
||||
- "DIV2K_valid_LR_bicubic_X3.zip"
|
||||
- "DIV2K_valid_LR_bicubic_X4.zip"
|
||||
- "DIV2K_valid_LR_unknown_X2.zip"
|
||||
- "DIV2K_valid_LR_unknown_X3.zip"
|
||||
- "DIV2K_valid_LR_unknown_X4.zip"
|
||||
|
||||
# net options
|
||||
pre_trained: ""
|
||||
rgb_range: 255
|
||||
rgb_mean: [0.4488, 0.4371, 0.4040]
|
||||
rgb_std: [1.0, 1.0, 1.0]
|
||||
n_colors: 3
|
||||
n_feats: 256
|
||||
kernel_size: 3
|
||||
n_resblocks: 32
|
||||
res_scale: 0.1
|
||||
|
||||
|
||||
---
|
||||
# helper
|
||||
|
||||
enable_modelarts: "set True if run in modelarts, default: False"
|
||||
# Url for modelarts
|
||||
data_url: "modelarts data path"
|
||||
train_url: "modelarts code path"
|
||||
checkpoint_url: "modelarts checkpoint save path"
|
||||
# Path for local
|
||||
data_path: "local data path, data will be download from 'data_url', default: /cache/data"
|
||||
output_path: "local output path, checkpoint will be upload to 'checkpoint_url', default: /cache/train"
|
||||
device_target: "choice from ['Ascend'], default: Ascend"
|
||||
|
||||
# ==============================================================================
|
||||
# train options
|
||||
amp_level: "choice from ['O0', 'O2', 'O3', 'auto'], default: O3"
|
||||
loss_scale: "loss scale will be used except 'O0', default: 1000.0"
|
||||
keep_checkpoint_max: "max number of checkpoints to be saved, defalue: 60"
|
||||
save_epoch_frq: "frequency to save checkpoint, defalue: 100"
|
||||
ckpt_save_dir: "the relative path to save checkpoint, root path is 'output_path', defalue: ./ckpt/"
|
||||
epoch_size: "the number of training epochs, defalue: 6000"
|
||||
|
||||
# eval options
|
||||
eval_epoch_frq: "frequency to evaluate model, defalue: 20"
|
||||
self_ensemble: "set True if wanna do self-ensemble while evaluating, defalue: True"
|
||||
save_sr: "set True if wanna save sr and hr image while evaluating, defalue: True"
|
||||
|
||||
# opt options
|
||||
opt_type: "optimizer type, choice from ['Adam'], defalue: Adam"
|
||||
weight_decay: "weight_decay for optimizer, defalue: 0.0"
|
||||
|
||||
# learning rate options
|
||||
learning_rate: "learning rate, defalue: 0.0001"
|
||||
milestones: "the key epoch to do a gamma decay, defalue: [4000]"
|
||||
gamma: "gamma decay rate, defalue: 0.5"
|
||||
|
||||
# dataset options
|
||||
dataset_name: "dataset name, defalue: DIV2K"
|
||||
lr_type: "lr image degeneration type, choice from ['bicubic', 'unknown'], defalue: bicubic"
|
||||
batch_size: "batch size for training; total batch size = 16 is recommended, defalue: 2"
|
||||
patch_size: "cut hr images into patch size for training, lr images auto-adjust by 'scale', defalue: 192"
|
||||
scale: "scale for super resolution reconstruction, choice from [2,3,4], defalue: 4"
|
||||
dataset_sink_mode: "set True if wanna using dataset sink mode, defalue: True"
|
||||
need_unzip_in_modelarts: "set True if wanna unzip data after download data from s3, defalue: False"
|
||||
need_unzip_files: "list of zip files to unzip, only work while 'need_unzip_in_modelarts'=True"
|
||||
|
||||
# net options
|
||||
pre_trained: "load pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3_ABS_PATH], [RELATIVE_PATH below 'output_path'], [LOCAL_ABS_PATH], ''], defalue: ''"
|
||||
rgb_range: "pix value range, defalue: 255"
|
||||
rgb_mean: "rgb mean, defalue: [0.4488, 0.4371, 0.4040]"
|
||||
rgb_std: "rgb standard deviation, defalue: [1.0, 1.0, 1.0]"
|
||||
n_colors: "the number of RGB image channels, defalue: 3"
|
||||
n_feats: "the number of output features for each Conv2d, defalue: 256"
|
||||
kernel_size: "kernel size for Conv2d, defalue: 3"
|
||||
n_resblocks: "the number of resblocks, defalue: 32"
|
||||
res_scale: "zoom scale of res branch, defalue: 0.1"
|
|
@ -1,4 +1,6 @@
|
|||
目录
|
||||
# 目录
|
||||
|
||||
[View English](./README.md)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
|
@ -6,22 +8,27 @@
|
|||
- [EDSR描述](#EDSR描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [特性](#特性)
|
||||
- [混合精度](#混合精度)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [训练](#训练)
|
||||
- [分布式训练](#分布式训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [导出](#导出)
|
||||
- [导出脚本](#导出脚本)
|
||||
- [推理过程](#推理过程)
|
||||
- [推理](#推理)
|
||||
- [在昇腾310上使用DIV2K数据集进行推理](#在昇腾310上使用DIV2K数据集进行推理)
|
||||
- [在昇腾310上使用其他数据集进行推理](#在昇腾310上使用其他数据集进行推理)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [训练性能](#训练性能)
|
||||
- [DIV2K上的EDSR](#DIV2K上的EDSR)
|
||||
- [DIV2K上的训练2倍/3倍/4倍超分辨率重建的EDSR](#DIV2K上的训练2倍/3倍/4倍超分辨率重建的EDSR)
|
||||
- [评估性能](#评估性能)
|
||||
- [Set5,Set14,B100,Urban100上的EDSR](#Set5,Set14,B100,Urban100上的EDSR)
|
||||
- [DIV2K上的评估2倍/3倍/4倍超分辨率重建的EDSR](#DIV2K上的评估2倍/3倍/4倍超分辨率重建的EDSR)
|
||||
- [推理性能](#推理性能)
|
||||
- [DIV2K上的推理2倍/3倍/4倍超分辨率重建的EDSR](#DIV2K上的推理2倍/3倍/4倍超分辨率重建的EDSR)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
|
@ -29,113 +36,144 @@
|
|||
|
||||
# EDSR描述
|
||||
|
||||
EDSR是2017年提出的32层深度网络,在2017年图像恢复和增强的新趋势研讨会上的超分挑战(NTIRE2017 Super-Resolution Challenge)中获得第一名。 EDSR,相比于SRResNet减少了每个残差块中的batch normalization层,SRResNet相对于原本的ResNet则在每个残差块的出口减去了ReLU层.
|
||||
[论文](https://arxiv.org/abs/1707.02921):Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** *2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**.
|
||||
增强的深度超分辨率网络(EDSR)是2017年提出的单图超分辨重建网络,在NTIRE2017超分辨重建比赛中获取第一名。它通过删除传统剩余网络中不必要的模块(BatchNorm),扩大模型的大小,同时应用了稳定训练的方法进行优化,显著提升了性能。
|
||||
|
||||
[论文](https://arxiv.org/pdf/1707.02921.pdf):Lim B , Son S , Kim H , et al. Enhanced Deep Residual Networks for Single Image Super-Resolution[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW). IEEE, 2017.
|
||||
|
||||
# 模型架构
|
||||
|
||||
EDSR先经过1次卷积层,再串联32个残差模块,再经过1次卷积层,最后上采样并卷积。
|
||||
EDSR是由多个优化后的residual blocks串联而成,相比原始版本的residual blocks,EDSR的residual blocks删除了BatchNorm层和最后一个ReLU层。删除BatchNorm使网络降低了40%的显存使用率和获得更快的计算效率,从而可以增加网络深度和宽度。EDSR的主干模式使用32个residual blocks堆叠而成,每个卷积层的卷积核数量256个,Residual scaling是0.1,损失函数是L1。
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的数据集:[DIV2K](<http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf>)
|
||||
使用的数据集:[DIV2K](<https://data.vision.ee.ethz.ch/cvl/DIV2K/>)
|
||||
|
||||
- 数据集大小:7.11G,
|
||||
- 数据集大小:7.11G,共1000组(HR,LRx2,LRx3,LRx4)有效彩色图像
|
||||
- 训练集:6.01G,共800组图像
|
||||
- 验证集:783.68M,共100组图像
|
||||
- 测试集:349.53M,共100组图像(无HR图)
|
||||
- 数据格式:PNG图片文件文件
|
||||
- 注:数据将在src/dataset.py中处理。
|
||||
|
||||
- 训练集:共800张图像(采用了前800张进行训练)
|
||||
- 测试集:共100张图像
|
||||
# 特性
|
||||
|
||||
- 数据格式:png文件
|
||||
## 混合精度
|
||||
|
||||
- 注:数据将在src/data/DIV2K.py中处理。
|
||||
|
||||
```shell
|
||||
DIV2K
|
||||
├── DIV2K_test_LR_bicubic
|
||||
│ ├── X2
|
||||
│ │ ├── 0901x2.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x2.png
|
||||
│ ├── X3
|
||||
│ │ ├── 0901x3.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x3.png
|
||||
│ └── X4
|
||||
│ ├── 0901x4.png
|
||||
│ ├─ ...
|
||||
│ └── 1000x4.png
|
||||
├── DIV2K_test_LR_unknown
|
||||
│ ├── X2
|
||||
│ │ ├── 0901x2.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x2.png
|
||||
│ ├── X3
|
||||
│ │ ├── 0901x3.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 1000x3.png
|
||||
│ └── X4
|
||||
│ ├── 0901x4.png
|
||||
│ ├─ ...
|
||||
│ └── 1000x4.png
|
||||
├── DIV2K_train_HR
|
||||
│ ├── 0001.png
|
||||
│ ├─ ...
|
||||
│ └── 0900.png
|
||||
├── DIV2K_train_LR_bicubic
|
||||
│ ├── X2
|
||||
│ │ ├── 0001x2.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 0900x2.png
|
||||
│ ├── X3
|
||||
│ │ ├── 0001x3.png
|
||||
│ │ ├─ ...
|
||||
│ │ └── 0900x3.png
|
||||
│ └── X4
|
||||
│ ├── 0001x4.png
|
||||
│ ├─ ...
|
||||
│ └── 0900x4.png
|
||||
└── DIV2K_train_LR_unknown
|
||||
├── X2
|
||||
│ ├── 0001x2.png
|
||||
│ ├─ ...
|
||||
│ └── 0900x2.png
|
||||
├── X3
|
||||
│ ├── 0001x3.png
|
||||
│ ├─ ...
|
||||
│ └── 0900x3.png
|
||||
└── X4
|
||||
├── 0001x4.png
|
||||
├─ ...
|
||||
└── 0900x4.png
|
||||
```
|
||||
采用[混合精度](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/enable_mixed_precision.html?highlight=%E6%B7%B7%E5%90%88%E7%B2%BE%E5%BA%A6)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 使用ascend处理器来搭建硬件环境。
|
||||
- 使用Ascend处理器来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
|
||||
- [MindSpore教程](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估。对于分布式训练,需要提前创建JSON格式的hccl配置文件。请遵循以下链接中的说明:
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
|
||||
|
||||
```shell
|
||||
#单卡训练
|
||||
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
|
||||
```
|
||||
- Ascend-910处理器环境运行单卡训练DIV2K
|
||||
|
||||
```shell
|
||||
#分布式训练
|
||||
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
```
|
||||
```python
|
||||
# 运行训练示例(EDSR(x2) in the paper)
|
||||
python train.py --batch_size 16 --scale 2 --config_path DIV2K_config.yaml > train.log 2>&1 &
|
||||
# 运行训练示例(EDSR(x3) in the paper - from EDSR(x2))
|
||||
python train.py --batch_size 16 --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path] train.log 2>&1 &
|
||||
# 运行训练示例(EDSR(x4) in the paper - from EDSR(x2))
|
||||
python train.py --batch_size 16 --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path] train.log 2>&1 &
|
||||
```
|
||||
|
||||
- Ascend-910处理器环境运行8卡训练DIV2K
|
||||
|
||||
```python
|
||||
# 运行分布式训练示例(EDSR(x2) in the paper)
|
||||
bash scripts/run_train.sh rank_table.json --scale 2 --config_path DIV2K_config.yaml
|
||||
# 运行分布式训练示例(EDSR(x3) in the paper)
|
||||
bash scripts/run_train.sh rank_table.json --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
|
||||
# 运行分布式训练示例(EDSR(x4) in the paper)
|
||||
bash scripts/run_train.sh rank_table.json --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
|
||||
```
|
||||
|
||||
- Ascend-910处理器环境运行单卡评估DIV2K
|
||||
|
||||
```python
|
||||
# 运行评估示例(EDSR(x2) in the paper)
|
||||
python eval.py --scale 2 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path] > train.log 2>&1 &
|
||||
# 运行评估示例(EDSR(x3) in the paper)
|
||||
python eval.py --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x3 model path] > train.log 2>&1 &
|
||||
# 运行评估示例(EDSR(x4) in the paper)
|
||||
python eval.py --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x4 model path] > train.log 2>&1 &
|
||||
```
|
||||
|
||||
- Ascend-910处理器环境运行8卡评估DIV2K
|
||||
|
||||
```python
|
||||
# 运行分布式评估示例(EDSR(x2) in the paper)
|
||||
bash scripts/run_eval.sh rank_table.json --scale 2 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
|
||||
# 运行分布式评估示例(EDSR(x3) in the paper)
|
||||
bash scripts/run_eval.sh rank_table.json --scale 3 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x3 model path]
|
||||
# 运行分布式评估示例(EDSR(x4) in the paper)
|
||||
bash scripts/run_eval.sh rank_table.json --scale 4 --config_path DIV2K_config.yaml --pre_trained [pre-trained EDSR_x4 model path]
|
||||
```
|
||||
|
||||
- Ascend-910处理器环境运行单卡评估benchmark
|
||||
|
||||
```python
|
||||
# 运行评估示例(EDSR(x2) in the paper)
|
||||
python eval.py --scale 2 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x2 model path] > train.log 2>&1 &
|
||||
# 运行评估示例(EDSR(x3) in the paper)
|
||||
python eval.py --scale 3 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x3 model path] > train.log 2>&1 &
|
||||
# 运行评估示例(EDSR(x4) in the paper)
|
||||
python eval.py --scale 4 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x4 model path] > train.log 2>&1 &
|
||||
```
|
||||
|
||||
- Ascend-910处理器环境运行8卡评估benchmark
|
||||
|
||||
```python
|
||||
# 运行分布式评估示例(EDSR(x2) in the paper)
|
||||
bash scripts/run_eval.sh rank_table.json --scale 2 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x2 model path]
|
||||
# 运行分布式评估示例(EDSR(x3) in the paper)
|
||||
bash scripts/run_eval.sh rank_table.json --scale 3 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x3 model path]
|
||||
# 运行分布式评估示例(EDSR(x4) in the paper)
|
||||
bash scripts/run_eval.sh rank_table.json --scale 4 --config_path benchmark_config.yaml --pre_trained [pre-trained EDSR_x4 model path]
|
||||
```
|
||||
|
||||
- Ascend-310处理器环境运行单卡评估DIV2K
|
||||
|
||||
```python
|
||||
# 运行推理命令
|
||||
bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
|
||||
# 运行推理示例(EDSR(x2) in the paper)
|
||||
bash scripts/run_infer_310.sh ./mindir/EDSR_x2_DIV2K-6000_50_InputSize1020.mindir ./DIV2K 2 ./infer_x2.log 0
|
||||
# 运行推理示例(EDSR(x3) in the paper)
|
||||
bash scripts/run_infer_310.sh ./mindir/EDSR_x3_DIV2K-6000_50_InputSize680.mindir ./DIV2K 3 ./infer_x3.log 0
|
||||
# 运行推理示例(EDSR(x4) in the paper)
|
||||
bash scripts/run_infer_310.sh ./mindir/EDSR_x4_DIV2K-6000_50_InputSize510.mindir ./DIV2K 4 ./infer_x4.log 0
|
||||
```
|
||||
|
||||
- 在 ModelArts 上训练 DIV2K 数据集
|
||||
|
||||
如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```python
|
||||
#评估
|
||||
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
|
||||
# (1) 选择上传代码到 S3 桶
|
||||
# 选择代码目录/s3_path_to_code/EDSR/
|
||||
# 选择启动文件/s3_path_to_code/EDSR/train.py
|
||||
# (2) 在网页上设置参数, DIV2K_config.yaml中的参数均可以在网页上配置
|
||||
# scale = 2
|
||||
# config_path = /local_path_to_code/DIV2K_config.yaml
|
||||
# enable_modelarts = True
|
||||
# pre_trained = [模型s3地址] 或者 [不设置]
|
||||
# [其他参数] = [参数值]
|
||||
# (3) 上传DIV2K数据集到 S3 桶上, 配置"训练数据集"路径,如果未解压,可以在(2)中配置
|
||||
# need_unzip_in_modelarts = True
|
||||
# (4) 在网页上设置"训练输出文件路径"、"作业日志路径"等
|
||||
# (5) 选择8卡/单卡机器,创建训练作业
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
@ -144,150 +182,218 @@ sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
|
|||
|
||||
```bash
|
||||
├── model_zoo
|
||||
├── README.md // 所有模型相关说明
|
||||
├── EDSR
|
||||
├── README_CN.md //自述文件
|
||||
├── eval.py //评估脚本
|
||||
├── export.py //导出脚本
|
||||
├── script
|
||||
│ ├── run_ascend_distribute.sh //Ascend分布式训练shell脚本
|
||||
│ ├── run_ascend_standalone.sh //Ascend单卡训练shell脚本
|
||||
│ └── run_eval.sh //eval验证shell脚本
|
||||
├── README_CN.md // EDSR说明
|
||||
├── model_utils // 上云的工具脚本
|
||||
├── DIV2K_config.yaml // EDSR参数
|
||||
├── scripts
|
||||
│ ├──run_train.sh // 分布式到Ascend的shell脚本
|
||||
│ ├──run_eval.sh // Ascend评估的shell脚本
|
||||
│ ├──run_infer_310.sh // Ascend-310推理shell脚本
|
||||
├── src
|
||||
│ ├── args.py //超参数
|
||||
│ ├── common.py //公共网络模块
|
||||
│ ├── data
|
||||
│ │ ├── common.py //公共数据集
|
||||
│ │ ├── div2k.py //div2k数据集
|
||||
│ │ └── srdata.py //所有数据集
|
||||
│ ├── metrics.py //PSNR和SSIM计算器
|
||||
│ ├── model.py //EDSR网络
|
||||
│ └── utils.py //训练脚本
|
||||
└── train.py //训练脚本
|
||||
│ ├──dataset.py // 创建数据集
|
||||
│ ├──edsr.py // edsr网络架构
|
||||
│ ├──config.py // 参数配置
|
||||
│ ├──metric.py // 评估指标
|
||||
│ ├──utils.py // train.py/eval.py公用的代码段
|
||||
├── train.py // 训练脚本
|
||||
├── eval.py // 评估脚本
|
||||
├── export.py // 将checkpoint文件导出到air/mindir
|
||||
├── preprocess.py // Ascend-310推理的数据预处理脚本
|
||||
├── ascend310_infer
|
||||
│ ├──src // 实现Ascend-310推理源代码
|
||||
│ ├──inc // 实现Ascend-310推理源代码
|
||||
│ ├──build.sh // 构建Ascend-310推理程序的shell脚本
|
||||
│ ├──CMakeLists.txt // 构建Ascend-310推理程序的CMakeLists
|
||||
├── postprocess.py // Ascend-310推理的数据后处理脚本
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
主要参数如下:
|
||||
在DIV2K_config.yaml中可以同时配置训练参数和评估参数。benchmark_config.yaml中的同名参数是一样的定义。
|
||||
|
||||
```python
|
||||
-h, --help show this help message and exit
|
||||
--dir_data DIR_DATA dataset directory
|
||||
--data_train DATA_TRAIN
|
||||
train dataset name
|
||||
--data_test DATA_TEST
|
||||
test dataset name
|
||||
--data_range DATA_RANGE
|
||||
train/test data range
|
||||
--ext EXT dataset file extension
|
||||
--scale SCALE super resolution scale
|
||||
--patch_size PATCH_SIZE
|
||||
output patch size
|
||||
--rgb_range RGB_RANGE
|
||||
maximum value of RGB
|
||||
--n_colors N_COLORS number of color channels to use
|
||||
--no_augment do not use data augmentation
|
||||
--model MODEL model name
|
||||
--n_resblocks N_RESBLOCKS
|
||||
number of residual blocks
|
||||
--n_feats N_FEATS number of feature maps
|
||||
--res_scale RES_SCALE
|
||||
residual scaling
|
||||
--test_every TEST_EVERY
|
||||
do test per every N batches
|
||||
--epochs EPOCHS number of epochs to train
|
||||
--batch_size BATCH_SIZE
|
||||
input batch size for training
|
||||
--test_only set this option to test the model
|
||||
--lr LR learning rate
|
||||
--ckpt_save_path CKPT_SAVE_PATH
|
||||
path to save ckpt
|
||||
--ckpt_save_interval CKPT_SAVE_INTERVAL
|
||||
save ckpt frequency, unit is epoch
|
||||
--ckpt_save_max CKPT_SAVE_MAX
|
||||
max number of saved ckpt
|
||||
--ckpt_path CKPT_PATH
|
||||
path of saved ckpt
|
||||
--task_id TASK_ID
|
||||
- 可以使用以下语句可以打印配置说明
|
||||
|
||||
```python
|
||||
python train.py --config_path DIV2K_config.yaml --help
|
||||
```
|
||||
|
||||
- 可以直接查看DIV2K_config.yaml内的配置说明,说明如下
|
||||
|
||||
```yaml
|
||||
enable_modelarts: "在云道运行则需要配置为True, default: False"
|
||||
|
||||
data_url: "云道数据路径"
|
||||
train_url: "云道代码路径"
|
||||
checkpoint_url: "云道保存的路径"
|
||||
|
||||
data_path: "运行机器的数据路径,由脚本从云道数据路径下载,default: /cache/data"
|
||||
output_path: "运行机器的输出路径,由脚本从本地上传至checkpoint_url,default: /cache/train"
|
||||
device_target: "可选['Ascend'],default: Ascend"
|
||||
|
||||
amp_level: "可选['O0', 'O2', 'O3', 'auto'],default: O3"
|
||||
loss_scale: "除了O0外,其他混合精度时会做loss放缩,default: 1000.0"
|
||||
keep_checkpoint_max: "最多保存多少个ckpt, defalue: 60"
|
||||
save_epoch_frq: "每隔多少epoch保存ckpt一次, defalue: 100"
|
||||
ckpt_save_dir: "保存的本地相对路径,根目录是output_path, defalue: ./ckpt/"
|
||||
epoch_size: "训练多少个epoch, defalue: 6000"
|
||||
|
||||
eval_epoch_frq: "训练时每隔多少epoch执行一次验证,defalue: 20"
|
||||
self_ensemble: "验证时执行self_ensemble,仅在eval.py中使用, defalue: True"
|
||||
save_sr: "验证时保存sr和hr图片,仅在eval.py中使用, defalue: True"
|
||||
|
||||
opt_type: "优化器类型,可选['Adam'],defalue: Adam"
|
||||
weight_decay: "优化器权重衰减参数,defalue: 0.0"
|
||||
|
||||
learning_rate: "学习率,defalue: 0.0001"
|
||||
milestones: "学习率衰减的epoch节点列表,defalue: [4000]"
|
||||
gamma: "学习率衰减率,defalue: 0.5"
|
||||
|
||||
dataset_name: "数据集名称,defalue: DIV2K"
|
||||
lr_type: "lr图的退化方式,可选['bicubic', 'unknown'],defalue: bicubic"
|
||||
batch_size: "为了保证效果,建议8卡用2,单卡用16,defalue: 2"
|
||||
patch_size: "训练时候的裁剪HR图大小,LR图会依据scale调整裁剪大小,defalue: 192"
|
||||
scale: "模型的超分辨重建的尺度,可选[2,3,4], defalue: 4"
|
||||
dataset_sink_mode: "训练使用数据下沉模式,defalue: True"
|
||||
need_unzip_in_modelarts: "从s3下载数据后加压数据,defalue: False"
|
||||
need_unzip_files: "需要解压的数据列表, need_unzip_in_modelarts=True时才起作用"
|
||||
|
||||
pre_trained: "加载预训练模型,x2/x3/x4倍可以相互加载,可选[[s3绝对地址], [output_path下相对地址], [本地机器绝对地址], ''],defalue: ''"
|
||||
rgb_range: "图片像素的范围,defalue: 255"
|
||||
rgb_mean: "图片RGB均值,defalue: [0.4488, 0.4371, 0.4040]"
|
||||
rgb_std: "图片RGB方差,defalue: [1.0, 1.0, 1.0]"
|
||||
n_colors: "RGB图片3通道,defalue: 3"
|
||||
n_feats: "每个卷积层的输出特征数量,defalue: 256"
|
||||
kernel_size: "卷积核大小,defalue: 3"
|
||||
n_resblocks: "resblocks数量,defalue: 32"
|
||||
res_scale: "res的分支的系数,defalue: 0.1"
|
||||
```
|
||||
|
||||
## 导出
|
||||
|
||||
在运行推理之前我们需要先导出模型。Air模型只能在昇腾910环境上导出,mindir可以在任意环境上导出。batch_size只支持1。
|
||||
|
||||
### 导出脚本
|
||||
|
||||
```shell
|
||||
python export.py --config_path DIV2K_config.yaml --output_path [dir to save model] --scale [SCALE] --pre_trained [pre-trained EDSR_x[SCALE] model path]
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
## 推理过程
|
||||
|
||||
### 训练
|
||||
### 推理
|
||||
|
||||
- Ascend处理器环境运行
|
||||
#### 在昇腾310上使用DIV2K数据集进行推理
|
||||
|
||||
```bash
|
||||
sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
|
||||
- 推理脚本
|
||||
|
||||
```shell
|
||||
bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
|
||||
```
|
||||
|
||||
如果数据集保存路径为G:\DIV2K,`TRAIN_DATA_DIR`应传入G:\。
|
||||
- 范例
|
||||
|
||||
上述python命令将在后台运行,您可以通过train.log文件查看结果。
|
||||
|
||||
### 分布式训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
```shell
|
||||
# 运行推理示例(EDSR(x2) in the paper)
|
||||
bash scripts/run_infer_310.sh ./mindir/EDSR_x2_DIV2K-6000_50_InputSize1020.mindir ./DIV2K 2 ./infer_x2.log 0
|
||||
# 运行推理示例(EDSR(x3) in the paper)
|
||||
bash scripts/run_infer_310.sh ./mindir/EDSR_x3_DIV2K-6000_50_InputSize680.mindir ./DIV2K 3 ./infer_x3.log 0
|
||||
# 运行推理示例(EDSR(x4) in the paper)
|
||||
bash scripts/run_infer_310.sh ./mindir/EDSR_x4_DIV2K-6000_50_InputSize510.mindir ./DIV2K 4 ./infer_x4.log 0
|
||||
```
|
||||
|
||||
如果数据集保存路径为G:\DIV2K,`TRAIN_DATA_DIR`应传入G:\。
|
||||
- 推理指标,分别查看infer_x2.log、infer_x3.log、infer_x4.log可以看到
|
||||
|
||||
## 评估过程
|
||||
```python
|
||||
# EDSR(x2) in the paper
|
||||
evaluation result = {'psnr': 35.068791459971266}
|
||||
# EDSR(x3) in the paper
|
||||
evaluation result = {'psnr': 31.386362838892456}
|
||||
# EDSR(x4) in the paper
|
||||
evaluation result = {'psnr': 29.38072897971985}
|
||||
```
|
||||
|
||||
### 评估
|
||||
#### 在昇腾310上使用其他数据集进行推理
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。
|
||||
- 推理流程
|
||||
|
||||
```bash
|
||||
sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
|
||||
# (1) 整理数据集,lr图片统一padding到一个固定尺寸。参考preprocess.py
|
||||
# (2) 根据固定尺寸导出模型,参考export.py
|
||||
# (3) 使用build.sh在ascend310_infer文件夹内编译推理程序,得到程序ascend310_infer/out/main
|
||||
# (4) 配置数据集图片路径,模型路径,输出路径等,使用main推理得到超分辨率重建图片。
|
||||
./ascend310_infer/out/main --mindir_path=[model] --dataset_path=[read_data_path] --device_id=[device_id] --save_dir=[save_data_path]
|
||||
# (5) 后处理图片,去除padding的无效区域。和hr图一起统计指标。参考preprocess.py
|
||||
```
|
||||
|
||||
`DATASET_TYPE`可选 ["Set5", "Set14", "B100", "Urban100", "DIV2K"]
|
||||
|
||||
如果数据集保存路径为G:\DIV2K或者G:\Set5或者G:\Set14或者G:\B100或者G:\Urban100,`TRAIN_DATA_DIR`应传入G:\。
|
||||
您可以通过log.txt文件查看结果。
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 训练性能
|
||||
|
||||
| 参数 | Ascend |
|
||||
| ------------- | ------------------------------------------------------------ |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021-7-4 |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | DIV2K |
|
||||
| 训练参数 | epoch=1000, steps=1000, batch_size =16, lr=0.0001 |
|
||||
| 优化器 | Adam |
|
||||
| 损失函数 | L1 |
|
||||
| 输出 | 超分辨率图片 |
|
||||
| 损失 | 3.1 |
|
||||
| 速度 | 8卡:50.75毫秒/步 |
|
||||
| 总时长 | 8卡:12.865小时 |
|
||||
| 微调检查点 | 466.13 MB (.ckpt文件) |
|
||||
| 脚本 | [EDSR](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/EDSR) |
|
||||
#### DIV2K上的训练2倍/3倍/4倍超分辨率重建的EDSR
|
||||
|
||||
| 参数 | Ascend | Ascend | Ascend |
|
||||
| --- | --- | --- | --- |
|
||||
| 模型版本 | EDSR(x2) | EDSR(x3) | EDSR(x4) |
|
||||
| 资源 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 |
|
||||
| 上传日期 | 2021-09-01 | 2021-09-01 | 2021-09-01 |
|
||||
| MindSpore版本 | 1.2.0 | 1.2.0 | 1.2.0 |
|
||||
| 数据集 | DIV2K | DIV2K | DIV2K |
|
||||
| 训练参数 | epoch=6000, 总batch_size=16, lr=0.0001, patch_size=192 | epoch=6000, 总batch_size=16, lr=0.0001, patch_size=192 | epoch=6000, 总batch_size=16, lr=0.0001, patch_size=192 |
|
||||
| 优化器 | Adam | Adam | Adam |
|
||||
| 损失函数 | L1 | L1 | L1 |
|
||||
| 输出 | 超分辨率重建RGB图 | 超分辨率重建RGB图 | 超分辨率重建RGB图 |
|
||||
| 损失 | 4.06 | 4.01 | 4.50 |
|
||||
| 速度 | 1卡:16.5秒/epoch;8卡:2.76秒/epoch | 1卡:21.6秒/epoch;8卡:1.8秒/epoch | 1卡:21.0秒/epoch;8卡:1.8秒/epoch |
|
||||
| 总时长 | 单卡:1725分钟; 8卡:310分钟 | 单卡:2234分钟; 8卡:217分钟 | 单卡:2173分钟; 8卡:210分钟 |
|
||||
| 参数(M) | 40.73M | 43.68M | 43.09M |
|
||||
| 微调检查点 | 467.28 MB (.ckpt文件) | 501.04 MB (.ckpt文件) | 494.29 MB (.ckpt文件) |
|
||||
|
||||
### 评估性能
|
||||
|
||||
| 参数 | Ascend |
|
||||
| ------------- | ----------------------------------------------------------- |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021-7-4 |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | Set5,Set14,B100,Urban100 |
|
||||
| batch_size | 1 |
|
||||
| 输出 | 超分辨率图片 |
|
||||
| PSNR | Set5:38.2136, Set14:34.0081, B100:32.3590, Urban100:33.0162 |
|
||||
#### DIV2K上的评估2倍/3倍/4倍超分辨率重建的EDSR
|
||||
|
||||
| 参数 | Ascend | Ascend | Ascend |
|
||||
| --- | --- | --- | --- |
|
||||
| 模型版本 | EDSR(x2) | EDSR(x3) | EDSR(x4) |
|
||||
| 资源 | Ascend 910;系统 Euler2.8 | Ascend 910;系统 Euler2.8 | Ascend 910;系统 Euler2.8 |
|
||||
| 上传日期 | 2021-09-01 | 2021-09-01 | 2021-09-01 |
|
||||
| MindSpore版本 | 1.2.0 | 1.2.0 | 1.2.0 |
|
||||
| 数据集 | DIV2K, 100张图像 | DIV2K, 100张图像 | DIV2K, 100张图像 |
|
||||
| self_ensemble | True | True | True |
|
||||
| batch_size | 1 | 1 | 1 |
|
||||
| 输出 | 超分辨率重建RGB图 | 超分辨率重建RGB图 | 超分辨率重建RGB图 |
|
||||
| Set5 psnr | 38.275 db | 34.777 db | 32.618 db |
|
||||
| Set14 psnr | 34.059 db | 30.684 db | 28.928 db |
|
||||
| B100 psnr | 32.393 db | 29.332 db | 27.792 db |
|
||||
| Urban100 psnr | 32.970 db | 29.019 db | 26.849 db |
|
||||
| DIV2K psnr | 35.063 db | 31.380 db | 29.370 db |
|
||||
| 推理模型 | 467.28 MB (.ckpt文件) | 501.04 MB (.ckpt文件) | 494.29 MB (.ckpt文件) |
|
||||
|
||||
### 推理性能
|
||||
|
||||
#### DIV2K上的推理2倍/3倍/4倍超分辨率重建的EDSR
|
||||
|
||||
| 参数 | Ascend | Ascend | Ascend |
|
||||
| --- | --- | --- | --- |
|
||||
| 模型版本 | EDSR(x2) | EDSR(x3) | EDSR(x4) |
|
||||
| 资源 | Ascend 310;系统 ubuntu18.04 | Ascend 310;系统 ubuntu18.04 | Ascend 310;系统 ubuntu18.04 |
|
||||
| 上传日期 | 2021-09-01 | 2021-09-01 | 2021-09-01 |
|
||||
| MindSpore版本 | 1.2.0 | 1.2.0 | 1.2.0 |
|
||||
| 数据集 | DIV2K, 100张图像 | DIV2K, 100张图像 | DIV2K, 100张图像 |
|
||||
| self_ensemble | True | True | True |
|
||||
| batch_size | 1 | 1 | 1 |
|
||||
| 输出 | 超分辨率重建RGB图 | 超分辨率重建RGB图 | 超分辨率重建RGB图 |
|
||||
| DIV2K psnr | 35.068 db | 31.386 db | 29.380 db |
|
||||
| 推理模型 | 156 MB (.mindir文件) | 167 MB (.mindir文件) | 165 MB (.mindir文件) |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
在train.py中,我们设置了“train_net”函数内的种子。
|
||||
在train.py,eval.py中,我们设置了mindspore.common.set_seed(2021)种子。
|
||||
|
||||
# ModelZoo主页
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
option(MINDSPORE_PATH "mindspore install path" "")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ ! -d out ]; then
|
||||
mkdir out
|
||||
fi
|
||||
cd out || exit
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
std::string RealPath(std::string_view path);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs,
|
||||
const std::string &homePath);
|
||||
#endif
|
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/dataset/vision_ascend.h"
|
||||
#include "include/dataset/execute.h"
|
||||
#include "include/dataset/vision.h"
|
||||
#include "inc/utils.h"
|
||||
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Status;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::dataset::MapTargetDevice;
|
||||
using mindspore::dataset::TensorTransform;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::Decode;
|
||||
using mindspore::dataset::vision::CenterCrop;
|
||||
|
||||
DEFINE_string(mindir_path, "", "mindir path");
|
||||
DEFINE_string(dataset_path, ".", "dataset path");
|
||||
DEFINE_string(save_dir, "", "save dir");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
if (RealPath(FLAGS_mindir_path).empty()) {
|
||||
std::cout << "Invalid mindir" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
DIR *dir = OpenDir(FLAGS_save_dir);
|
||||
if (dir == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(FLAGS_device_id);
|
||||
ascend310->SetBufferOptimizeMode("off_optimize");
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
mindspore::Graph graph;
|
||||
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||
|
||||
Model model;
|
||||
Status ret = model.Build(GraphCell(graph), context);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Build failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto decode = Decode();
|
||||
auto normalize = Normalize({0.0, 0.0, 0.0}, {1.0, 1.0, 1.0});
|
||||
auto hwc2chw = HWC2CHW();
|
||||
Execute transform({decode, normalize, hwc2chw});
|
||||
|
||||
auto all_files = GetAllFiles(FLAGS_dataset_path);
|
||||
std::map<double, double> costTime_map;
|
||||
size_t size = all_files.size();
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
double startTimeMs = 0.0;
|
||||
double endTimeMs = 0.0;
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
std::cout << "Start predict input files:" << all_files[i] << std::endl;
|
||||
auto img = MSTensor();
|
||||
auto image = ReadFileToTensor(all_files[i]);
|
||||
transform(image, &img);
|
||||
std::vector<MSTensor> model_inputs = model.GetInputs();
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
img.Data().get(), img.DataSize());
|
||||
gettimeofday(&start, nullptr);
|
||||
ret = model.Predict(inputs, &outputs);
|
||||
gettimeofday(&end, nullptr);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||
WriteResult(all_files[i], outputs, FLAGS_save_dir);
|
||||
}
|
||||
double average = 0.0;
|
||||
int inferCount = 0;
|
||||
|
||||
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||
double diff = 0.0;
|
||||
diff = iter->second - iter->first;
|
||||
average += diff;
|
||||
inferCount++;
|
||||
}
|
||||
|
||||
average = average / inferCount;
|
||||
std::stringstream timeCost;
|
||||
timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
|
||||
std::cout << "NN inference cost average time: " << average << "ms of infer_count " << inferCount << std::endl;
|
||||
std::string fileName = FLAGS_save_dir + std::string("/test_perform_static.txt");
|
||||
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||
fileStream << timeCost.str();
|
||||
fileStream.close();
|
||||
costTime_map.clear();
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "inc/utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dirName);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> dirs;
|
||||
std::vector<std::string> files;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == "..") {
|
||||
continue;
|
||||
} else if (filename->d_type == DT_DIR) {
|
||||
dirs.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
} else if (filename->d_type == DT_REG) {
|
||||
files.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto d : dirs) {
|
||||
dir = OpenDir(d);
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||
continue;
|
||||
}
|
||||
files.emplace_back(std::string(d) + "/" + filename->d_name);
|
||||
}
|
||||
}
|
||||
std::sort(files.begin(), files.end());
|
||||
for (auto &f : files) {
|
||||
std::cout << "image file: " << f << std::endl;
|
||||
}
|
||||
return files;
|
||||
}
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs, const std::string &homePath) {
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
std::shared_ptr<const void> netOutput;
|
||||
netOutput = outputs[i].Data();
|
||||
size_t outputSize = outputs[i].DataSize();
|
||||
int pos = imageFile.rfind('/');
|
||||
std::string fileName(imageFile, pos + 1);
|
||||
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||
std::string outFileName = homePath + "/" + fileName;
|
||||
FILE * outputFile = fopen(outFileName.c_str(), "wb");
|
||||
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << "open failed" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
mindspore::MSTensor buffer(
|
||||
file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
|
||||
if (dirName.empty()) {
|
||||
std::cout << " dirName is null ! " << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = RealPath(dirName);
|
||||
struct stat s;
|
||||
lstat(realPath.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
DIR *dir = opendir(realPath.c_str());
|
||||
if (dir == nullptr) {
|
||||
std::cout << "Can not open dir " << dirName << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||
return dir;
|
||||
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
|
||||
char realPathMem[PATH_MAX] = {0};
|
||||
char *realPathRet = nullptr;
|
||||
realPathRet = realpath(path.data(), realPathMem);
|
||||
|
||||
if (realPathRet == nullptr) {
|
||||
std::cout << "File: " << path << " is not exist.";
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string realPath(realPathMem);
|
||||
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||
return realPath;
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
device_target: "Ascend"
|
||||
|
||||
# ==============================================================================
|
||||
ckpt_save_dir: "./ckpt/"
|
||||
|
||||
self_ensemble: True
|
||||
save_sr: True
|
||||
|
||||
# dataset options
|
||||
dataset_name: "benchmark"
|
||||
scale: 4
|
||||
need_unzip_in_modelarts: False
|
||||
|
||||
# net options
|
||||
pre_trained: ""
|
||||
rgb_range: 255
|
||||
rgb_mean: [0.4488, 0.4371, 0.4040]
|
||||
rgb_std: [1.0, 1.0, 1.0]
|
||||
n_colors: 3
|
||||
n_feats: 256
|
||||
kernel_size: 3
|
||||
n_resblocks: 32
|
||||
res_scale: 0.1
|
||||
|
||||
---
|
||||
# helper
|
||||
|
||||
enable_modelarts: "set True if run in modelarts, default: False"
|
||||
# Url for modelarts
|
||||
data_url: "modelarts data path"
|
||||
train_url: "modelarts code path"
|
||||
checkpoint_url: "modelarts checkpoint save path"
|
||||
# Path for local
|
||||
data_path: "local data path, data will be download from 'data_url', default: /cache/data"
|
||||
output_path: "local output path, checkpoint will be upload to 'checkpoint_url', default: /cache/train"
|
||||
device_target: "choice from ['Ascend'], default: Ascend"
|
||||
|
||||
# ==============================================================================
|
||||
# train options
|
||||
ckpt_save_dir: "the relative path to save checkpoint, root path is 'output_path', defalue: ./ckpt/"
|
||||
|
||||
self_ensemble: "set True if wanna do self-ensemble while evaluating, defalue: True"
|
||||
save_sr: "set True if wanna save sr and hr image while evaluating, defalue: True"
|
||||
|
||||
# dataset options
|
||||
dataset_name: "dataset name, defalue: DIV2K"
|
||||
scale: "scale for super resolution reconstruction, choice from [2,3,4], defalue: 4"
|
||||
need_unzip_in_modelarts: "set True if wanna unzip data after download data from s3, defalue: False"
|
||||
need_unzip_files: "list of zip files to unzip, only work while 'need_unzip_in_modelarts'=True"
|
||||
|
||||
# net options
|
||||
pre_trained: "load pre-trained model, x2/x3/x4 models can be loaded for each other, choice from [[S3_ABS_PATH], [RELATIVE_PATH below 'output_path'], [LOCAL_ABS_PATH], ''], defalue: ''"
|
||||
rgb_range: "pix value range, defalue: 255"
|
||||
rgb_mean: "rgb mean, defalue: [0.4488, 0.4371, 0.4040]"
|
||||
rgb_std: "rgb standard deviation, defalue: [1.0, 1.0, 1.0]"
|
||||
n_colors: "the number of RGB image channels, defalue: 3"
|
||||
n_feats: "the number of output features for each Conv2d, defalue: 256"
|
||||
kernel_size: "kernel size for Conv2d, defalue: 3"
|
||||
n_resblocks: "the number of resblocks, defalue: 32"
|
||||
res_scale: "zoom scale of res branch, defalue: 0.1"
|
|
@ -12,64 +12,189 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""edsr eval script"""
|
||||
"""
|
||||
#################evaluate EDSR example on DIV2K########################
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.args import args
|
||||
import src.model as edsr
|
||||
from src.data.srdata import SRData
|
||||
from src.data.div2k import DIV2K
|
||||
from src.metrics import calc_psnr, quantize, calc_ssim
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
|
||||
context.set_context(max_call_depth=10000)
|
||||
def eval_net():
|
||||
"""eval"""
|
||||
if args.epochs == 0:
|
||||
args.epochs = 1e8
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
||||
if args.data_test[0] == 'DIV2K':
|
||||
train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
|
||||
else:
|
||||
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
|
||||
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
|
||||
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
|
||||
net_m = edsr.EDSR(args)
|
||||
if args.ckpt_path:
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
net_m.set_train(False)
|
||||
|
||||
print('load mindspore net successfully.')
|
||||
num_imgs = train_de_dataset.get_dataset_size()
|
||||
psnrs = np.zeros((num_imgs, 1))
|
||||
ssims = np.zeros((num_imgs, 1))
|
||||
for batch_idx, imgs in enumerate(train_loader):
|
||||
lr = imgs['LR']
|
||||
hr = imgs['HR']
|
||||
lr = Tensor(lr, mstype.float32)
|
||||
pred = net_m(lr)
|
||||
pred_np = pred.asnumpy()
|
||||
pred_np = quantize(pred_np, 255)
|
||||
psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
|
||||
pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
|
||||
hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
|
||||
ssim = calc_ssim(pred_np, hr, args.scale[0])
|
||||
print("current psnr: ", psnr)
|
||||
print("current ssim: ", ssim)
|
||||
psnrs[batch_idx, 0] = psnr
|
||||
ssims[batch_idx, 0] = ssim
|
||||
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
|
||||
print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
|
||||
import numpy as np
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import Tensor, ops
|
||||
|
||||
from src.metric import SelfEnsembleWrapperNumpy, PSNR, SaveSrHr
|
||||
from src.utils import init_env, init_dataset, init_net, modelarts_pre_process, do_eval
|
||||
from src.dataset import get_rank_info, LrHrImages, hwc2chw, uint8_to_float32
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
set_seed(2021)
|
||||
|
||||
|
||||
class HrCutter:
|
||||
"""
|
||||
cut hr into correct shape, for evaluating benchmark
|
||||
"""
|
||||
def __init__(self, lr_scale):
|
||||
self.lr_scale = lr_scale
|
||||
|
||||
def __call__(self, lr, hr):
|
||||
lrh, lrw, _ = lr.shape
|
||||
hrh, hrw, _ = hr.shape
|
||||
h, w = lrh * self.lr_scale, lrw * self.lr_scale
|
||||
if hrh != h or hrw != w:
|
||||
hr = hr[0:h, 0:w, :]
|
||||
return lr, hr
|
||||
|
||||
|
||||
class RepeatDataSet:
|
||||
"""
|
||||
Repeat DataSet so that it can dist evaluate Set5
|
||||
"""
|
||||
def __init__(self, dataset, repeat):
|
||||
self.dataset = dataset
|
||||
self.repeat = repeat
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.dataset[idx % len(self.dataset)]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset) * self.repeat
|
||||
|
||||
|
||||
def create_dataset_benchmark(dataset_path, scale):
|
||||
"""
|
||||
create a train or eval benchmark dataset
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
scale(int): lr scale, read data ordered by it, choices=(2,3,4)
|
||||
Returns:
|
||||
multi_datasets
|
||||
"""
|
||||
lr_scale = scale
|
||||
|
||||
multi_datasets = {}
|
||||
for dataset_name in ["Set5", "Set14", "B100", "Urban100"]:
|
||||
# get HR_PATH/*.png
|
||||
dir_hr = os.path.join(dataset_path, dataset_name, "HR")
|
||||
hr_pattern = os.path.join(dir_hr, "*.png")
|
||||
|
||||
# get LR
|
||||
column_names = [f"lrx{lr_scale}", "hr"]
|
||||
dir_lr = os.path.join(dataset_path, dataset_name, "LR_bicubic", f"X{lr_scale}")
|
||||
lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
|
||||
lrs_pattern = [lr_pattern]
|
||||
|
||||
device_num, rank_id = get_rank_info()
|
||||
|
||||
# make dataset
|
||||
dataset = LrHrImages(lr_pattern=lrs_pattern, hr_pattern=hr_pattern)
|
||||
if len(dataset) < device_num:
|
||||
dataset = RepeatDataSet(dataset, repeat=device_num // len(dataset) + 1)
|
||||
|
||||
# make mindspore dataset
|
||||
if device_num == 1 or device_num is None:
|
||||
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
|
||||
num_parallel_workers=3,
|
||||
shuffle=False)
|
||||
else:
|
||||
sampler = ds.DistributedSampler(num_shards=device_num, shard_id=rank_id, shuffle=False, offset=0)
|
||||
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
|
||||
num_parallel_workers=3,
|
||||
sampler=sampler)
|
||||
|
||||
# define map operations
|
||||
transform_img = [
|
||||
HrCutter(lr_scale),
|
||||
hwc2chw,
|
||||
uint8_to_float32,
|
||||
]
|
||||
|
||||
# pre-process hr lr
|
||||
generator_dataset = generator_dataset.map(input_columns=column_names,
|
||||
output_columns=column_names,
|
||||
column_order=column_names,
|
||||
operations=transform_img)
|
||||
|
||||
# apply batch operations
|
||||
generator_dataset = generator_dataset.batch(1, drop_remainder=False)
|
||||
|
||||
multi_datasets[dataset_name] = generator_dataset
|
||||
return multi_datasets
|
||||
|
||||
|
||||
class BenchmarkPSNR(PSNR):
|
||||
"""
|
||||
eval psnr for Benchmark
|
||||
"""
|
||||
def __init__(self, rgb_range, shave, channels_scale):
|
||||
super(BenchmarkPSNR, self).__init__(rgb_range=rgb_range, shave=shave)
|
||||
self.channels_scale = channels_scale
|
||||
self.c_scale = Tensor(np.array(self.channels_scale, dtype=np.float32).reshape((1, -1, 1, 1)))
|
||||
self.sum = ops.ReduceSum(keep_dims=True)
|
||||
|
||||
def update(self, *inputs):
|
||||
if len(inputs) != 2:
|
||||
raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
|
||||
sr, hr = inputs
|
||||
sr = self.quantize(sr)
|
||||
diff = (sr - hr) / self.rgb_range
|
||||
diff = diff * self.c_scale
|
||||
valid = self.sum(diff, 1)
|
||||
if self.shave is not None and self.shave != 0:
|
||||
valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)]
|
||||
mse_list = (valid ** 2).mean(axis=(1, 2, 3))
|
||||
mse_list = self._convert_data(mse_list).tolist()
|
||||
psnr_list = [float(1e32) if mse == 0 else(- 10.0 * math.log10(mse)) for mse in mse_list]
|
||||
self._accumulate(psnr_list)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
"""
|
||||
run eval
|
||||
"""
|
||||
print(config)
|
||||
cfg = config
|
||||
|
||||
init_env(cfg)
|
||||
net = init_net(cfg)
|
||||
eval_net = SelfEnsembleWrapperNumpy(net) if cfg.self_ensemble else net
|
||||
|
||||
if cfg.dataset_name == "DIV2K":
|
||||
cfg.batch_size = 1
|
||||
cfg.patch_size = -1
|
||||
ds_val = init_dataset(cfg, "valid")
|
||||
metrics = {
|
||||
"psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + cfg.scale),
|
||||
}
|
||||
if config.save_sr:
|
||||
save_img_dir = os.path.join(cfg.output_path, "HrSr")
|
||||
os.makedirs(save_img_dir, exist_ok=True)
|
||||
metrics["num_sr"] = SaveSrHr(save_img_dir)
|
||||
do_eval(eval_net, ds_val, metrics)
|
||||
print("eval success", flush=True)
|
||||
|
||||
elif cfg.dataset_name == "benchmark":
|
||||
multi_datasets = create_dataset_benchmark(dataset_path=cfg.data_path, scale=cfg.scale)
|
||||
result = {}
|
||||
for dname, ds_val in multi_datasets.items():
|
||||
dpnsr = f"{dname}_psnr"
|
||||
gray_coeffs = [65.738, 129.057, 25.064]
|
||||
channels_scale = [x / 256.0 for x in gray_coeffs]
|
||||
metrics = {
|
||||
dpnsr: BenchmarkPSNR(rgb_range=cfg.rgb_range, shave=cfg.scale, channels_scale=channels_scale)
|
||||
}
|
||||
if config.save_sr:
|
||||
save_img_dir = os.path.join(cfg.output_path, "HrSr", dname)
|
||||
os.makedirs(save_img_dir, exist_ok=True)
|
||||
metrics["num_sr"] = SaveSrHr(save_img_dir)
|
||||
result[dpnsr] = do_eval(eval_net, ds_val, metrics)[dpnsr]
|
||||
if get_rank_id() == 0:
|
||||
print(result, flush=True)
|
||||
print("eval success", flush=True)
|
||||
else:
|
||||
raise RuntimeError("Unsupported dataset.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Start eval function!")
|
||||
eval_net()
|
||||
run_eval()
|
||||
|
|
|
@ -12,50 +12,53 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export net together with checkpoint into air/mindir/onnx models"""
|
||||
"""
|
||||
##############export checkpoint file into air, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
import src.model as edsr
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, export, context
|
||||
|
||||
parser = argparse.ArgumentParser(description='edsr export')
|
||||
parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
|
||||
parser.add_argument("--file_name", type=str, default="edsr", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
|
||||
|
||||
parser.add_argument('--scale', type=str, default='2', help='super resolution scale')
|
||||
parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB')
|
||||
parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use')
|
||||
parser.add_argument('--n_resblocks', type=int, default=32, help='number of residual blocks')
|
||||
parser.add_argument('--n_feats', type=int, default=256, help='number of feature maps')
|
||||
parser.add_argument('--res_scale', type=float, default=0.1, help='residual scaling')
|
||||
parser.add_argument('--task_id', type=int, default=0)
|
||||
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
|
||||
args1 = parser.parse_args()
|
||||
args1.scale = [int(x) for x in args1.scale.split("+")]
|
||||
for arg in vars(args1):
|
||||
if vars(args1)[arg] == 'True':
|
||||
vars(args1)[arg] = True
|
||||
elif vars(args1)[arg] == 'False':
|
||||
vars(args1)[arg] = False
|
||||
|
||||
def run_export(args):
|
||||
"""run_export"""
|
||||
device_id = int(os.getenv("DEVICE_ID", '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
||||
net = edsr.EDSR(args)
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
print('load mindspore net and checkpoint successfully.')
|
||||
inputs = Tensor(np.zeros([args.batch_size, 3, 678, 1020], np.float32))
|
||||
export(net, inputs, file_name=args.file_name, file_format=args.file_format)
|
||||
print('export successfully!')
|
||||
from src.utils import init_net
|
||||
from model_utils.config import config
|
||||
from model_utils.device_adapter import get_device_id
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export(args1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
MAX_HR_SIZE = 2040
|
||||
|
||||
|
||||
@moxing_wrapper()
|
||||
def run_export():
|
||||
"""
|
||||
run export
|
||||
"""
|
||||
print(config)
|
||||
|
||||
cfg = config
|
||||
if cfg.pre_trained is None:
|
||||
raise RuntimeError('config.pre_trained is None.')
|
||||
|
||||
net = init_net(cfg)
|
||||
max_lr_size = MAX_HR_SIZE // cfg.scale
|
||||
input_arr = Tensor(np.ones([1, cfg.n_colors, max_lr_size, max_lr_size]), ms.float32)
|
||||
file_name = os.path.splitext(os.path.basename(cfg.pre_trained))[0]
|
||||
file_name = file_name + f"_InputSize{max_lr_size}"
|
||||
file_path = os.path.join(cfg.output_path, file_name)
|
||||
file_format = 'MINDIR'
|
||||
|
||||
num_params = sum([param.size for param in net.parameters_dict().values()])
|
||||
export(net, input_arr, file_name=file_path, file_format=file_format)
|
||||
print(f"export success", flush=True)
|
||||
print(f"{cfg.pre_trained} -> {file_path}.{file_format.lower()}, net parameters = {num_params/1000000:>0.4}M",
|
||||
flush=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_export()
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.edsr import EDSR
|
||||
|
||||
|
||||
def edsr(*args, **kwargs):
|
||||
return EDSR(*args, **kwargs)
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == "edsr":
|
||||
return edsr(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
final_config = Config(final_config)
|
||||
return final_config
|
||||
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -0,0 +1,191 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''post process for 310 inference'''
|
||||
import os
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
|
||||
from src.utils import init_env, modelarts_pre_process
|
||||
from src.dataset import FolderImagePair, AUG_DICT
|
||||
from src.metric import PSNR
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
def read_bin(bin_path):
|
||||
img = np.fromfile(bin_path, dtype=np.float32)
|
||||
num_pix = img.size
|
||||
img_shape = int(math.sqrt(num_pix // 3))
|
||||
if 1 * 3 * img_shape * img_shape != num_pix:
|
||||
raise RuntimeError(f'bin file error, it not output from edsr network, {bin_path}')
|
||||
img = img.reshape(1, 3, img_shape, img_shape)
|
||||
return img
|
||||
|
||||
|
||||
def read_bin_as_hwc(bin_path):
|
||||
nchw_img = read_bin(bin_path)
|
||||
chw_img = np.squeeze(nchw_img)
|
||||
hwc_img = chw_img.transpose(1, 2, 0)
|
||||
return hwc_img
|
||||
|
||||
|
||||
def unpadding(img, target_shape):
|
||||
h, w = target_shape[0], target_shape[1]
|
||||
img_h, img_w, _ = img.shape
|
||||
if img_h > h:
|
||||
img = img[:h, :, :]
|
||||
if img_w > w:
|
||||
img = img[:, :w, :]
|
||||
return img
|
||||
|
||||
|
||||
def img_to_tensor(img):
|
||||
img = np.array([img.transpose(2, 0, 1)], np.float32)
|
||||
img = Tensor(img)
|
||||
return img
|
||||
|
||||
|
||||
def float_to_uint8(img):
|
||||
clip_img = np.clip(img, 0, 255)
|
||||
round_img = np.round(clip_img)
|
||||
uint8_img = round_img.astype(np.uint8)
|
||||
return uint8_img
|
||||
|
||||
|
||||
def bin_to_png(cfg):
|
||||
"""
|
||||
bin from ascend310_infer outputs will be covert to png
|
||||
"""
|
||||
dataset_path = cfg.data_path
|
||||
dataset_type = "valid"
|
||||
aug_keys = list(AUG_DICT.keys())
|
||||
lr_scale = cfg.scale
|
||||
|
||||
if cfg.self_ensemble:
|
||||
dir_sr_bin = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_bin", f"X{lr_scale}")
|
||||
save_sr_se_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_self_ensemble", f"X{lr_scale}")
|
||||
if os.path.isdir(dir_sr_bin):
|
||||
os.makedirs(save_sr_se_dir, exist_ok=True)
|
||||
bin_patterns = [os.path.join(dir_sr_bin, f"*x{lr_scale}_{a_key}_0.bin") for a_key in aug_keys]
|
||||
dataset = FolderImagePair(bin_patterns, reader=read_bin_as_hwc)
|
||||
for i in range(len(dataset)):
|
||||
img_key = dataset.get_key(i)
|
||||
sr_se_path = os.path.join(save_sr_se_dir, f"{img_key}x{lr_scale}.png")
|
||||
if os.path.isfile(sr_se_path):
|
||||
continue
|
||||
data = dataset[i]
|
||||
img_key, sr_8 = data[0], data[1:]
|
||||
sr = np.zeros_like(sr_8[0], dtype=np.float64)
|
||||
for img, a_key in zip(sr_8, aug_keys):
|
||||
aug = AUG_DICT[a_key]
|
||||
for a in reversed(aug):
|
||||
img = a(img)
|
||||
sr += img
|
||||
sr /= len(sr_8)
|
||||
sr = float_to_uint8(sr)
|
||||
Image.fromarray(sr).save(sr_se_path)
|
||||
print(f"merge sr bin save to {sr_se_path}")
|
||||
return
|
||||
|
||||
if not cfg.self_ensemble:
|
||||
dir_sr_bin = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR_bin", f"X{lr_scale}")
|
||||
save_sr_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR", f"X{lr_scale}")
|
||||
if os.path.isdir(dir_sr_bin):
|
||||
os.makedirs(save_sr_dir, exist_ok=True)
|
||||
bin_patterns = [os.path.join(dir_sr_bin, f"*x{lr_scale}_0_0.bin")]
|
||||
dataset = FolderImagePair(bin_patterns, reader=read_bin_as_hwc)
|
||||
for i in range(len(dataset)):
|
||||
img_key = dataset.get_key(i)
|
||||
sr_path = os.path.join(save_sr_dir, f"{img_key}x{lr_scale}.png")
|
||||
if os.path.isfile(sr_path):
|
||||
continue
|
||||
img_key, sr = dataset[i]
|
||||
sr = float_to_uint8(sr)
|
||||
Image.fromarray(sr).save(sr_path)
|
||||
print(f"merge sr bin save to {sr_path}")
|
||||
return
|
||||
|
||||
|
||||
def get_hr_sr_dataset(cfg):
|
||||
"""
|
||||
make hr sr dataset
|
||||
"""
|
||||
dataset_path = cfg.data_path
|
||||
dataset_type = "valid"
|
||||
lr_scale = cfg.scale
|
||||
|
||||
dir_patterns = []
|
||||
|
||||
# get HR_PATH/*.png
|
||||
dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
|
||||
hr_pattern = os.path.join(dir_hr, "*.png")
|
||||
dir_patterns.append(hr_pattern)
|
||||
|
||||
# get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
|
||||
se = "_self_ensemble" if cfg.self_ensemble else ""
|
||||
|
||||
dir_sr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_SR" + se, f"X{lr_scale}")
|
||||
if not os.path.isdir(dir_sr):
|
||||
raise RuntimeError(f'{dir_sr} is not a dir for saving sr')
|
||||
sr_pattern = os.path.join(dir_sr, f"*x{lr_scale}.png")
|
||||
dir_patterns.append(sr_pattern)
|
||||
|
||||
# make dataset
|
||||
dataset = FolderImagePair(dir_patterns)
|
||||
return dataset
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_post_process():
|
||||
"""
|
||||
run post process
|
||||
"""
|
||||
print(config)
|
||||
cfg = config
|
||||
lr_scale = cfg.scale
|
||||
|
||||
init_env(cfg)
|
||||
|
||||
print("begin to run bin_to_png...")
|
||||
bin_to_png(cfg)
|
||||
print("bin_to_png finish")
|
||||
|
||||
dataset = get_hr_sr_dataset(cfg)
|
||||
|
||||
metrics = {
|
||||
"psnr": PSNR(rgb_range=cfg.rgb_range, shave=6 + lr_scale),
|
||||
}
|
||||
|
||||
total_step = len(dataset)
|
||||
setw = len(str(total_step))
|
||||
for i in range(len(dataset)):
|
||||
_, hr, sr = dataset[i]
|
||||
sr = unpadding(sr, hr.shape)
|
||||
sr = img_to_tensor(sr)
|
||||
hr = img_to_tensor(hr)
|
||||
_ = [m.update(sr, hr) for m in metrics.values()]
|
||||
result = {k: m.eval(sync=False) for k, m in metrics.items()}
|
||||
print(f"[{i+1:>{setw}}/{total_step:>{setw}}] result = {result}", flush=True)
|
||||
result = {k: m.eval(sync=False) for k, m in metrics.items()}
|
||||
print(f"evaluation result = {result}", flush=True)
|
||||
|
||||
print("post_process success", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_post_process()
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''pre process for 310 inference'''
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from src.utils import modelarts_pre_process
|
||||
from src.dataset import FolderImagePair, AUG_DICT
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
MAX_HR_SIZE = 2040
|
||||
|
||||
|
||||
def padding(img, target_shape):
|
||||
h, w = target_shape[0], target_shape[1]
|
||||
img_h, img_w, _ = img.shape
|
||||
dh, dw = h - img_h, w - img_w
|
||||
if dh < 0 or dw < 0:
|
||||
raise RuntimeError(f"target_shape is bigger than img.shape, {target_shape} > {img.shape}")
|
||||
if dh != 0 or dw != 0:
|
||||
img = np.pad(img, ((0, dh), (0, dw), (0, 0)), "constant")
|
||||
return img
|
||||
|
||||
|
||||
def get_lr_dataset(cfg):
|
||||
"""
|
||||
get lr dataset
|
||||
"""
|
||||
dataset_path = cfg.data_path
|
||||
lr_scale = cfg.scale
|
||||
lr_type = cfg.lr_type
|
||||
dataset_type = "valid"
|
||||
self_ensemble = "_self_ensemble" if cfg.self_ensemble else ""
|
||||
|
||||
# get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
|
||||
lrs_pattern = []
|
||||
dir_lr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}", f"X{lr_scale}")
|
||||
lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
|
||||
lrs_pattern.append(lr_pattern)
|
||||
save_dir = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}_AUG{self_ensemble}", f"X{lr_scale}")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_format = os.path.join(save_dir, "{}" + f"x{lr_scale}" + "_{}.png")
|
||||
|
||||
# make dataset
|
||||
dataset = FolderImagePair(lrs_pattern)
|
||||
|
||||
return dataset, save_format
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_pre_process():
|
||||
"""
|
||||
run pre process
|
||||
"""
|
||||
print(config)
|
||||
cfg = config
|
||||
|
||||
aug_dict = AUG_DICT
|
||||
if not cfg.self_ensemble:
|
||||
aug_dict = {"0": AUG_DICT["0"]}
|
||||
|
||||
dataset, save_format = get_lr_dataset(cfg)
|
||||
for i in range(len(dataset)):
|
||||
img_key = dataset.get_key(i)
|
||||
org_img = None
|
||||
for a_key, aug in aug_dict.items():
|
||||
save_path = save_format.format(img_key, a_key)
|
||||
if os.path.isfile(save_path):
|
||||
continue
|
||||
if org_img is None:
|
||||
_, lr = dataset[i]
|
||||
target_shape = [MAX_HR_SIZE // cfg.scale, MAX_HR_SIZE // cfg.scale]
|
||||
org_img = padding(lr, target_shape)
|
||||
img = org_img.copy()
|
||||
for a in aug:
|
||||
img = a(img)
|
||||
Image.fromarray(img).save(save_path)
|
||||
print(f"[{i+1}/{len(dataset)}]\tsave {save_path}\tshape = {img.shape}", flush=True)
|
||||
|
||||
print("pre_process success", flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_pre_process()
|
|
@ -1,72 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
|
||||
python train.py \
|
||||
--batch_size 2 \
|
||||
--lr 1e-4 \
|
||||
--scale 2 \
|
||||
--task_id 0 \
|
||||
--dir_data $PATH2 \
|
||||
--epochs 1000 \
|
||||
--test_every 8000 \
|
||||
--n_resblocks 32 \
|
||||
--n_feats 256 \
|
||||
--res_scale 0.1 \
|
||||
--patch_size 48 > train.log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -1,59 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
|
||||
env >env.log
|
||||
|
||||
python train.py \
|
||||
--batch_size 16 \
|
||||
--lr 1e-4 \
|
||||
--scale 2 \
|
||||
--task_id 0 \
|
||||
--dir_data $PATH1 \
|
||||
--epochs 1000 \
|
||||
--test_every 1000 \
|
||||
--n_resblocks 32 \
|
||||
--n_feats 256 \
|
||||
--res_scale 0.1 \
|
||||
--patch_size 48 > train.log 2>&1 &
|
|
@ -1,65 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
DATASET_TYPE=$3
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]; then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation ..."
|
||||
|
||||
python eval.py \
|
||||
--dir_data=${PATH1} \
|
||||
--test_only \
|
||||
--ext img \
|
||||
--ckpt_path=${PATH2} \
|
||||
--task_id 0 \
|
||||
--scale 2 \
|
||||
--data_test=${DATASET_TYPE} \
|
||||
--device_id 0 \
|
||||
--n_resblocks 32 \
|
||||
--n_feats 256 \
|
||||
--res_scale 0.1 > log.txt 2>&1 &
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_eval.sh [RANK_TABLE_FILE] --opt1 opt1_value --opt2 opt2_value ..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
PATH1=$(realpath $1)
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
echo "RANK_TABLE_FILE=${PATH1}"
|
||||
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./eval_parallel$i
|
||||
mkdir ./eval_parallel$i
|
||||
cp -r ./src ./eval_parallel$i
|
||||
cp -r ./model_utils ./eval_parallel$i
|
||||
cp -r ./*.yaml ./eval_parallel$i
|
||||
cp ./eval.py ./eval_parallel$i
|
||||
echo "start evaluation for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./eval_parallel$i ||exit
|
||||
env > env.log
|
||||
export args=${*:2}
|
||||
python eval.py $args > eval.log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,151 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 3 || $# -gt 5 ]]; then
|
||||
echo "Usage: bash scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SCALE] [LOG_FILE] [DEVICE_ID]
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, default: 0"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
model=$(get_real_path $1)
|
||||
data_path=$(get_real_path $2)
|
||||
scale=$3
|
||||
|
||||
if [[ $scale -ne "2" && $scale -ne "3" && $scale -ne "4" ]]; then
|
||||
echo "[SCALE] should be in [2,3,4]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_file="./run_infer.log"
|
||||
if [ $# -gt 4 ]; then
|
||||
log_file=$4
|
||||
fi
|
||||
log_file=$(get_real_path $log_file)
|
||||
|
||||
|
||||
device_id=0
|
||||
if [ $# == 5 ]; then
|
||||
device_id=$5
|
||||
fi
|
||||
|
||||
self_ensemble="True"
|
||||
|
||||
echo "***************** param *****************"
|
||||
echo "mindir name: "$model
|
||||
echo "dataset path: "$data_path
|
||||
echo "scale: "$scale
|
||||
echo "log file: "$log_file
|
||||
echo "device id: "$device_id
|
||||
echo "self_ensemble: "$self_ensemble
|
||||
echo "***************** param *****************"
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
echo "begin to compile app..."
|
||||
cd ./ascend310_infer || exit
|
||||
bash build.sh >> $log_file 2>&1
|
||||
cd -
|
||||
echo "finshi compile app"
|
||||
}
|
||||
|
||||
function preprocess()
|
||||
{
|
||||
echo "begin to preprocess..."
|
||||
export DEVICE_ID=$device_id
|
||||
export RANK_SIZE=1
|
||||
python preprocess.py --data_path=$data_path --config_path=DIV2K_config.yaml --device_target=CPU --scale=$scale --self_ensemble=$self_ensemble >> $log_file 2>&1
|
||||
echo "finshi preprocess"
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
echo "begin to infer..."
|
||||
if [ $self_ensemble == "True" ]; then
|
||||
read_data_path=$data_path"/DIV2K_valid_LR_bicubic_AUG_self_ensemble/X"$scale
|
||||
else
|
||||
read_data_path=$data_path"/DIV2K_valid_LR_bicubic_AUG/X"$scale
|
||||
fi
|
||||
save_data_path=$data_path"/DIV2K_valid_SR_bin/X"$scale
|
||||
if [ -d $save_data_path ]; then
|
||||
rm -rf $save_data_path
|
||||
fi
|
||||
mkdir -p $save_data_path
|
||||
./ascend310_infer/out/main --mindir_path=$model --dataset_path=$read_data_path --device_id=$device_id --save_dir=$save_data_path >> $log_file 2>&1
|
||||
echo "finshi infer"
|
||||
}
|
||||
|
||||
function postprocess()
|
||||
{
|
||||
echo "begin to postprocess..."
|
||||
export DEVICE_ID=$device_id
|
||||
export RANK_SIZE=1
|
||||
python postprocess.py --data_path=$data_path --config_path=DIV2K_config.yaml --device_target=CPU --scale=$scale --self_ensemble=$self_ensemble >> $log_file 2>&1
|
||||
echo "finshi postprocess"
|
||||
}
|
||||
|
||||
echo "" > $log_file
|
||||
echo "read the log command: "
|
||||
echo " tail -f $log_file"
|
||||
|
||||
compile_app
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed, check $log_file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
preprocess
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "preprocess code failed, check $log_file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
infer
|
||||
if [ $? -ne 0 ]; then
|
||||
echo " execute inference failed, check $log_file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
postprocess
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "postprocess failed, check $log_file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cat $log_file | tail -n 3 | head -n 1
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_train.sh [RANK_TABLE_FILE] --opt1 opt1_value --opt2 opt2_value ..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
PATH1=$(realpath $1)
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
echo "RANK_TABLE_FILE=${PATH1}"
|
||||
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp -r ./model_utils ./train_parallel$i
|
||||
cp -r ./*.yaml ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
export args=${*:2}
|
||||
python train.py $args > train.log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""args parser"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='EDSR')
|
||||
# Data specifications
|
||||
parser.add_argument('--dir_data', type=str, default='/cache/data/',
|
||||
help='dataset directory')
|
||||
parser.add_argument('--data_train', type=str, default='DIV2K',
|
||||
help='train dataset name')
|
||||
parser.add_argument('--data_test', type=str, default='DIV2K',
|
||||
help='test dataset name')
|
||||
parser.add_argument('--data_range', type=str, default='1-800/801-900',
|
||||
help='train/test data range')
|
||||
parser.add_argument('--ext', type=str, default='sep',
|
||||
help='dataset file extension')
|
||||
parser.add_argument('--scale', type=str, default='4',
|
||||
help='super resolution scale')
|
||||
parser.add_argument('--patch_size', type=int, default=48,
|
||||
help='input patch size')
|
||||
parser.add_argument('--rgb_range', type=int, default=255,
|
||||
help='maximum value of RGB')
|
||||
parser.add_argument('--n_colors', type=int, default=3,
|
||||
help='number of color channels to use')
|
||||
parser.add_argument('--no_augment', action='store_true',
|
||||
help='do not use data augmentation')
|
||||
# Model specifications
|
||||
parser.add_argument('--model', default='EDSR',
|
||||
help='model name')
|
||||
parser.add_argument('--n_resblocks', type=int, default=32,
|
||||
help='number of residual blocks')
|
||||
parser.add_argument('--n_feats', type=int, default=256,
|
||||
help='number of feature maps')
|
||||
parser.add_argument('--res_scale', type=float, default=0.1,
|
||||
help='residual scaling')
|
||||
# Training specifications
|
||||
parser.add_argument('--test_every', type=int, default=8000,
|
||||
help='do test per every N batches')
|
||||
parser.add_argument('--epochs', type=int, default=1000,
|
||||
help='number of epochs to train')
|
||||
parser.add_argument('--batch_size', type=int, default=2,
|
||||
help='input batch size for training')
|
||||
parser.add_argument('--test_only', action='store_true',
|
||||
help='set this option to test the model')
|
||||
# Optimization specifications
|
||||
parser.add_argument('--lr', type=float, default=1e-4,
|
||||
help='learning rate')
|
||||
parser.add_argument('--loss_scale', type=float, default=1024.0,
|
||||
help='init loss scale')
|
||||
# ckpt specifications
|
||||
parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
|
||||
help='path to save ckpt')
|
||||
parser.add_argument('--ckpt_save_interval', type=int, default=10,
|
||||
help='save ckpt frequency, unit is epoch')
|
||||
parser.add_argument('--ckpt_save_max', type=int, default=5,
|
||||
help='max number of saved ckpt')
|
||||
parser.add_argument('--ckpt_path', type=str, default='',
|
||||
help='path of saved ckpt')
|
||||
# alltask
|
||||
parser.add_argument('--task_id', type=int, default=0)
|
||||
|
||||
args, unparsed = parser.parse_known_args()
|
||||
args.scale = [int(x) for x in args.scale.split("+")]
|
||||
args.data_train = args.data_train.split('+')
|
||||
args.data_test = args.data_test.split('+')
|
||||
if args.epochs == 0:
|
||||
args.epochs = 1e4
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
|
@ -1,96 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""common"""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
||||
return nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
pad_mode='pad',
|
||||
padding=(kernel_size//2), has_bias=bias)
|
||||
|
||||
|
||||
class MeanShift(mindspore.nn.Conv2d):
|
||||
"""MeanShift"""
|
||||
def __init__(
|
||||
self, rgb_range,
|
||||
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, dtype=mindspore.float32):
|
||||
|
||||
std = mindspore.Tensor(rgb_std, dtype)
|
||||
weight = mindspore.Tensor(np.eye(3), dtype).reshape(
|
||||
3, 3, 1, 1) / std.reshape(3, 1, 1, 1)
|
||||
bias = sign * rgb_range * mindspore.Tensor(rgb_mean, dtype) / std
|
||||
|
||||
super(MeanShift, self).__init__(3, 3, kernel_size=1,
|
||||
has_bias=True, weight_init=weight, bias_init=bias)
|
||||
|
||||
for p in self.get_parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
|
||||
class ResBlock(nn.Cell):
|
||||
"""ResBlock"""
|
||||
def __init__(
|
||||
self, conv, n_feats, kernel_size,
|
||||
bias=True, act=nn.ReLU(), res_scale=1):
|
||||
|
||||
super(ResBlock, self).__init__()
|
||||
m = []
|
||||
for i in range(2):
|
||||
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
|
||||
if i == 0:
|
||||
m.append(act)
|
||||
|
||||
self.body = nn.SequentialCell(m)
|
||||
self.res_scale = res_scale
|
||||
self.mul = mindspore.ops.Mul()
|
||||
|
||||
def construct(self, x):
|
||||
res = self.body(x)
|
||||
res = self.mul(res, self.res_scale)
|
||||
res += x
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class PixelShuffle(nn.Cell):
|
||||
"""PixelShuffle"""
|
||||
def __init__(self, upscale_factor):
|
||||
super(PixelShuffle, self).__init__()
|
||||
self.DepthToSpace = mindspore.ops.DepthToSpace(upscale_factor)
|
||||
|
||||
def construct(self, x):
|
||||
return self.DepthToSpace(x)
|
||||
|
||||
|
||||
def Upsampler(conv, scale, n_feats, bias=True):
|
||||
"""Upsampler"""
|
||||
m = []
|
||||
|
||||
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
|
||||
for _ in range(int(math.log(scale, 2))):
|
||||
m.append(conv(n_feats, 4 * n_feats, 3, bias))
|
||||
m.append(PixelShuffle(2))
|
||||
elif scale == 3:
|
||||
m.append(conv(n_feats, 9 * n_feats, 3, bias))
|
||||
m.append(PixelShuffle(3))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return m
|
|
@ -1,75 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""common"""
|
||||
import random
|
||||
import numpy as np
|
||||
def get_patch(*args, patch_size=96, scale=2, input_large=False):
|
||||
"""common"""
|
||||
ih, iw = args[0].shape[:2]
|
||||
tp = patch_size
|
||||
ip = tp // scale
|
||||
ix = random.randrange(0, iw - ip + 1)
|
||||
iy = random.randrange(0, ih - ip + 1)
|
||||
if not input_large:
|
||||
tx, ty = scale * ix, scale * iy
|
||||
else:
|
||||
tx, ty = ix, iy
|
||||
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
|
||||
return ret
|
||||
def set_channel(*args, n_channels=3):
|
||||
"""common"""
|
||||
def _set_channel(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
c = img.shape[2]
|
||||
if n_channels == 3 and c == 1:
|
||||
img = np.concatenate([img] * n_channels, 2)
|
||||
return img[:, :, :n_channels]
|
||||
return [_set_channel(a) for a in args]
|
||||
def np2Tensor(*args, rgb_range=255):
|
||||
def _np2Tensor(img):
|
||||
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
||||
input_data = np_transpose.astype(np.float32)
|
||||
output = input_data * (rgb_range / 255)
|
||||
return output
|
||||
return [_np2Tensor(a) for a in args]
|
||||
def augment(*args, hflip=True, rot=True):
|
||||
"""common"""
|
||||
hflip = hflip and random.random() < 0.5
|
||||
vflip = rot and random.random() < 0.5
|
||||
rot90 = rot and random.random() < 0.5
|
||||
def _augment(img):
|
||||
"""common"""
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
img = img[::-1, :, :]
|
||||
if rot90:
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
return [_augment(a) for a in args]
|
||||
def search(root, target="JPEG"):
|
||||
"""srdata"""
|
||||
item_list = []
|
||||
items = os.listdir(root)
|
||||
for item in items:
|
||||
path = os.path.join(root, item)
|
||||
if os.path.isdir(path):
|
||||
item_list.extend(search(path, target))
|
||||
elif path.split('/')[-1].startswith(target):
|
||||
item_list.append(path)
|
||||
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
|
||||
item_list.append(path)
|
||||
return item_list
|
|
@ -1,43 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""div2k"""
|
||||
import os
|
||||
from src.data.srdata import SRData
|
||||
class DIV2K(SRData):
|
||||
"""DIV2K"""
|
||||
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
|
||||
data_range = [r.split('-') for r in args.data_range.split('/')]
|
||||
if train:
|
||||
data_range = data_range[0]
|
||||
else:
|
||||
if args.test_only and len(data_range) == 1:
|
||||
data_range = data_range[0]
|
||||
else:
|
||||
data_range = data_range[1]
|
||||
self.begin, self.end = list(map(int, data_range))
|
||||
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
|
||||
self.dir_hr = None
|
||||
self.dir_lr = None
|
||||
|
||||
def _scan(self):
|
||||
names_hr, names_lr = super(DIV2K, self)._scan()
|
||||
names_hr = names_hr[self.begin - 1:self.end]
|
||||
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
|
||||
return names_hr, names_lr
|
||||
|
||||
def _set_filesystem(self, dir_data):
|
||||
super(DIV2K, self)._set_filesystem(dir_data)
|
||||
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
||||
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
|
|
@ -1,212 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""""srdata"""
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
import pickle
|
||||
import imageio
|
||||
from src.data import common
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
class SRData:
|
||||
"""srdata"""
|
||||
def __init__(self, args, name='', train=True, benchmark=False):
|
||||
self.args = args
|
||||
self.name = name
|
||||
self.train = train
|
||||
self.split = 'train' if train else 'test'
|
||||
self.do_eval = True
|
||||
self.benchmark = benchmark
|
||||
self.input_large = (args.model == 'VDSR')
|
||||
self.scale = args.scale
|
||||
self.idx_scale = 0
|
||||
self._set_filesystem(args.dir_data)
|
||||
self._set_img(args)
|
||||
if train:
|
||||
self._repeat(args)
|
||||
|
||||
def _set_img(self, args):
|
||||
"""srdata"""
|
||||
if args.ext.find('img') < 0:
|
||||
path_bin = os.path.join(self.apath, 'bin')
|
||||
os.makedirs(path_bin, exist_ok=True)
|
||||
list_hr, list_lr = self._scan()
|
||||
if args.ext.find('img') >= 0 or self.benchmark:
|
||||
self.images_hr, self.images_lr = list_hr, list_lr
|
||||
elif args.ext.find('sep') >= 0:
|
||||
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
|
||||
for s in self.scale:
|
||||
if s == 1:
|
||||
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
|
||||
else:
|
||||
os.makedirs(
|
||||
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
|
||||
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
|
||||
for h in list_hr:
|
||||
b = h.replace(self.apath, path_bin)
|
||||
b = b.replace(self.ext[0], '.pt')
|
||||
self.images_hr.append(b)
|
||||
self._check_and_load(args.ext, h, b, verbose=True)
|
||||
for i, ll in enumerate(list_lr):
|
||||
for l in ll:
|
||||
b = l.replace(self.apath, path_bin)
|
||||
b = b.replace(self.ext[1], '.pt')
|
||||
self.images_lr[i].append(b)
|
||||
self._check_and_load(args.ext, l, b, verbose=True)
|
||||
|
||||
def _repeat(self, args):
|
||||
"""srdata"""
|
||||
n_patches = args.batch_size * args.test_every
|
||||
n_images = len(args.data_train) * len(self.images_hr)
|
||||
if n_images == 0:
|
||||
self.repeat = 0
|
||||
else:
|
||||
self.repeat = max(n_patches // n_images, 1)
|
||||
|
||||
def _scan(self):
|
||||
"""srdata"""
|
||||
names_hr = sorted(
|
||||
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
|
||||
names_lr = [[] for _ in self.scale]
|
||||
for f in names_hr:
|
||||
filename, _ = os.path.splitext(os.path.basename(f))
|
||||
for si, s in enumerate(self.scale):
|
||||
if s != 1:
|
||||
scale = s
|
||||
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
|
||||
.format(s, filename, scale, self.ext[1])))
|
||||
for si, s in enumerate(self.scale):
|
||||
if s == 1:
|
||||
names_lr[si] = names_hr
|
||||
return names_hr, names_lr
|
||||
|
||||
def _set_filesystem(self, dir_data):
|
||||
self.apath = os.path.join(dir_data, self.name[0])
|
||||
self.dir_hr = os.path.join(self.apath, 'HR')
|
||||
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
|
||||
self.ext = ('.png', '.png')
|
||||
|
||||
def _check_and_load(self, ext, img, f, verbose=True):
|
||||
if not os.path.isfile(f) or ext.find('reset') >= 0:
|
||||
if verbose:
|
||||
print('Making a binary: {}'.format(f))
|
||||
with open(f, 'wb') as _f:
|
||||
pickle.dump(imageio.imread(img), _f)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
lr, hr, _ = self._load_file(idx)
|
||||
pair = self.get_patch(lr, hr)
|
||||
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
return pair_t[0], pair_t[1]
|
||||
|
||||
def __len__(self):
|
||||
if self.train:
|
||||
return len(self.images_hr) * self.repeat
|
||||
return len(self.images_hr)
|
||||
|
||||
def _get_index(self, idx):
|
||||
if self.train:
|
||||
return idx % len(self.images_hr)
|
||||
return idx
|
||||
|
||||
def _load_file_deblur(self, idx, train=True):
|
||||
"""srdata"""
|
||||
idx = self._get_index(idx)
|
||||
if train:
|
||||
f_hr = self.images_hr[idx]
|
||||
f_lr = self.images_lr[idx]
|
||||
else:
|
||||
f_hr = self.deblur_hr_test[idx]
|
||||
f_lr = self.deblur_lr_test[idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
||||
filename = f_hr[-27:-17] + filename
|
||||
hr = imageio.imread(f_hr)
|
||||
lr = imageio.imread(f_lr)
|
||||
return lr, hr, filename
|
||||
|
||||
def _load_file_hr(self, idx):
|
||||
"""srdata"""
|
||||
idx = self._get_index(idx)
|
||||
f_hr = self.images_hr[idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
||||
if self.args.ext == 'img' or self.benchmark:
|
||||
hr = imageio.imread(f_hr)
|
||||
elif self.args.ext.find('sep') >= 0:
|
||||
with open(f_hr, 'rb') as _f:
|
||||
hr = pickle.load(_f)
|
||||
return hr, filename
|
||||
|
||||
def _load_rain_test(self, idx):
|
||||
f_hr = self.derain_hr_test[idx]
|
||||
f_lr = self.derain_lr_test[idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_lr))
|
||||
norain = imageio.imread(f_hr)
|
||||
rain = imageio.imread(f_lr)
|
||||
return norain, rain, filename
|
||||
|
||||
def _load_file(self, idx):
|
||||
"""srdata"""
|
||||
idx = self._get_index(idx)
|
||||
f_hr = self.images_hr[idx]
|
||||
f_lr = self.images_lr[self.idx_scale][idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
||||
if self.args.ext == 'img' or self.benchmark:
|
||||
hr = imageio.imread(f_hr)
|
||||
lr = imageio.imread(f_lr)
|
||||
elif self.args.ext.find('sep') >= 0:
|
||||
with open(f_hr, 'rb') as _f:
|
||||
hr = pickle.load(_f)
|
||||
with open(f_lr, 'rb') as _f:
|
||||
lr = pickle.load(_f)
|
||||
return lr, hr, filename
|
||||
|
||||
def get_patch_hr(self, hr):
|
||||
"""srdata"""
|
||||
if self.train:
|
||||
hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
|
||||
return hr
|
||||
|
||||
def get_patch_img_hr(self, img, patch_size=96, scale=2):
|
||||
"""srdata"""
|
||||
ih, iw = img.shape[:2]
|
||||
tp = patch_size
|
||||
ip = tp // scale
|
||||
ix = random.randrange(0, iw - ip + 1)
|
||||
iy = random.randrange(0, ih - ip + 1)
|
||||
ret = img[iy:iy + ip, ix:ix + ip, :]
|
||||
return ret
|
||||
|
||||
def get_patch(self, lr, hr):
|
||||
"""srdata"""
|
||||
scale = self.scale[self.idx_scale]
|
||||
if self.train:
|
||||
lr, hr = common.get_patch(
|
||||
lr, hr,
|
||||
patch_size=self.args.patch_size * scale,
|
||||
scale=scale)
|
||||
if not self.args.no_augment:
|
||||
lr, hr = common.augment(lr, hr)
|
||||
else:
|
||||
ih, iw = lr.shape[:2]
|
||||
hr = hr[0:ih * scale, 0:iw * scale]
|
||||
return lr, hr
|
||||
|
||||
def set_scale(self, idx_scale):
|
||||
if not self.input_large:
|
||||
self.idx_scale = idx_scale
|
||||
else:
|
||||
self.idx_scale = random.randint(0, len(self.scale) - 1)
|
|
@ -0,0 +1,321 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import re
|
||||
from functools import reduce
|
||||
import random
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
from model_utils.moxing_adapter import get_rank_id, get_device_num
|
||||
return get_device_num(), get_rank_id()
|
||||
|
||||
|
||||
class FolderImagePair:
|
||||
"""
|
||||
get image pair
|
||||
dir_patterns(list): a list of image path patterns. such as ["/LR/*.jpg", "/HR/*.png"...]
|
||||
the file key is matched chars from * and ?
|
||||
reader(object/func): a method to read image by path.
|
||||
"""
|
||||
def __init__(self, dir_patterns, reader=None):
|
||||
self.dir_patterns = dir_patterns
|
||||
self.reader = reader
|
||||
self.pair_keys, self.image_pairs = self.scan_pair(self.dir_patterns)
|
||||
|
||||
@classmethod
|
||||
def scan_pair(cls, dir_patterns):
|
||||
"""
|
||||
scan pair
|
||||
"""
|
||||
images = []
|
||||
for _dir in dir_patterns:
|
||||
imgs = glob.glob(_dir)
|
||||
_dir = os.path.basename(_dir)
|
||||
pat = _dir.replace("*", "(.*)").replace("?", "(.?)")
|
||||
pat = re.compile(pat, re.I | re.M)
|
||||
keys = [re.findall(pat, os.path.basename(p))[0] for p in imgs]
|
||||
images.append({k: v for k, v in zip(keys, imgs)})
|
||||
same_keys = reduce(lambda x, y: set(x) & set(y), images)
|
||||
same_keys = sorted(same_keys)
|
||||
image_pairs = [[d[k] for d in images] for k in same_keys]
|
||||
same_keys = [x if isinstance(x, str) else "_".join(x) for x in same_keys]
|
||||
return same_keys, image_pairs
|
||||
|
||||
def get_key(self, idx):
|
||||
return self.pair_keys[idx]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.reader is None:
|
||||
images = [Image.open(p) for p in self.image_pairs[idx]]
|
||||
images = [img.convert('RGB') for img in images]
|
||||
images = [np.array(img) for img in images]
|
||||
else:
|
||||
images = [self.reader(p) for p in self.image_pairs[idx]]
|
||||
pair_key = self.pair_keys[idx]
|
||||
return (pair_key, *images)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pair_keys)
|
||||
|
||||
|
||||
class LrHrImages(FolderImagePair):
|
||||
"""
|
||||
make LrHrImages dataset
|
||||
"""
|
||||
def __init__(self, lr_pattern, hr_pattern, reader=None):
|
||||
self.hr_pattern = hr_pattern
|
||||
self.lr_pattern = lr_pattern
|
||||
self.dir_patterns = []
|
||||
if isinstance(self.lr_pattern, str):
|
||||
self.is_multi_lr = False
|
||||
self.dir_patterns.append(self.lr_pattern)
|
||||
elif len(lr_pattern) == 1:
|
||||
self.is_multi_lr = False
|
||||
self.dir_patterns.append(self.lr_pattern[0])
|
||||
else:
|
||||
self.is_multi_lr = True
|
||||
self.dir_patterns.extend(self.lr_pattern)
|
||||
self.dir_patterns.append(self.hr_pattern)
|
||||
super(LrHrImages, self).__init__(self.dir_patterns, reader=reader)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
_, *images = super(LrHrImages, self).__getitem__(idx)
|
||||
return images
|
||||
|
||||
class _BasePatchCutter:
|
||||
"""
|
||||
cut patch from images
|
||||
patch_size(int): patch size, input images should be bigger than patch_size.
|
||||
lr_scale(int/list): lr scales for input images. Choice from [1,2,3,4, or their combination]
|
||||
"""
|
||||
def __init__(self, patch_size, lr_scale):
|
||||
self.patch_size = patch_size
|
||||
self.multi_lr_scale = lr_scale
|
||||
if isinstance(lr_scale, int):
|
||||
self.multi_lr_scale = [lr_scale]
|
||||
else:
|
||||
self.multi_lr_scale = [*lr_scale]
|
||||
self.max_lr_scale_idx = self.multi_lr_scale.index(max(self.multi_lr_scale))
|
||||
self.max_lr_scale = self.multi_lr_scale[self.max_lr_scale_idx]
|
||||
|
||||
def get_tx_ty(self, target_height, target_weight, target_patch_size):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(self, *images):
|
||||
target_img = images[self.max_lr_scale_idx]
|
||||
|
||||
tp = self.patch_size // self.max_lr_scale
|
||||
th, tw, _ = target_img.shape
|
||||
|
||||
tx, ty = self.get_tx_ty(th, tw, tp)
|
||||
|
||||
patch_images = []
|
||||
for _, (img, lr_scale) in enumerate(zip(images, self.multi_lr_scale)):
|
||||
x = tx * self.max_lr_scale // lr_scale
|
||||
y = ty * self.max_lr_scale // lr_scale
|
||||
p = tp * self.max_lr_scale // lr_scale
|
||||
patch_images.append(img[y:(y + p), x:(x + p), :])
|
||||
return tuple(patch_images)
|
||||
|
||||
|
||||
class RandomPatchCutter(_BasePatchCutter):
|
||||
|
||||
def __init__(self, patch_size, lr_scale):
|
||||
super(RandomPatchCutter, self).__init__(patch_size=patch_size, lr_scale=lr_scale)
|
||||
|
||||
def get_tx_ty(self, target_height, target_weight, target_patch_size):
|
||||
target_x = random.randrange(0, target_weight - target_patch_size + 1)
|
||||
target_y = random.randrange(0, target_height - target_patch_size + 1)
|
||||
return target_x, target_y
|
||||
|
||||
|
||||
class CentrePatchCutter(_BasePatchCutter):
|
||||
|
||||
def __init__(self, patch_size, lr_scale):
|
||||
super(CentrePatchCutter, self).__init__(patch_size=patch_size, lr_scale=lr_scale)
|
||||
|
||||
def get_tx_ty(self, target_height, target_weight, target_patch_size):
|
||||
target_x = (target_weight - target_patch_size) // 2
|
||||
target_y = (target_height - target_patch_size) // 2
|
||||
return target_x, target_y
|
||||
|
||||
|
||||
def hflip(img):
|
||||
return img[:, ::-1, :]
|
||||
|
||||
|
||||
def vflip(img):
|
||||
return img[::-1, :, :]
|
||||
|
||||
|
||||
def trnsp(img):
|
||||
return img.transpose(1, 0, 2)
|
||||
|
||||
|
||||
AUG_LIST = [
|
||||
[],
|
||||
[trnsp],
|
||||
[vflip],
|
||||
[vflip, trnsp],
|
||||
[hflip],
|
||||
[hflip, trnsp],
|
||||
[hflip, vflip],
|
||||
[hflip, vflip, trnsp],
|
||||
]
|
||||
|
||||
|
||||
AUG_DICT = {
|
||||
"0": [],
|
||||
"t": [trnsp],
|
||||
"v": [vflip],
|
||||
"vt": [vflip, trnsp],
|
||||
"h": [hflip],
|
||||
"ht": [hflip, trnsp],
|
||||
"hv": [hflip, vflip],
|
||||
"hvt": [hflip, vflip, trnsp],
|
||||
}
|
||||
|
||||
|
||||
def flip_and_rotate(*images):
|
||||
aug = random.choice(AUG_LIST)
|
||||
res = []
|
||||
for img in images:
|
||||
for a in aug:
|
||||
img = a(img)
|
||||
res.append(img)
|
||||
return tuple(res)
|
||||
|
||||
|
||||
def hwc2chw(*images):
|
||||
res = [i.transpose(2, 0, 1) for i in images]
|
||||
return tuple(res)
|
||||
|
||||
|
||||
def uint8_to_float32(*images):
|
||||
res = [(i.astype(np.float32) if i.dtype == np.uint8 else i) for i in images]
|
||||
return tuple(res)
|
||||
|
||||
|
||||
def create_dataset_DIV2K(config, dataset_type="train", num_parallel_workers=10, shuffle=True):
|
||||
"""
|
||||
create a train or eval DIV2K dataset
|
||||
Args:
|
||||
config(dict):
|
||||
dataset_path(string): the path of dataset.
|
||||
scale(int/list): lr scale, read data ordered by it, choices=(2,3,4,[2],[3],[4],[2,3],[2,4],[3,4],[2,3,4])
|
||||
lr_type(string): lr images type, choices=("bicubic", "unknown"), Default "bicubic"
|
||||
batch_size(int): the batch size of dataset. (train prarm), Default 1
|
||||
patch_size(int): train data size. (train param), Default -1
|
||||
epoch_size(int): times to repeat dataset for dataset_sink_mode, Default None
|
||||
dataset_type(string): choices=("train", "valid", "test"), Default "train"
|
||||
num_parallel_workers(int): num-workers to read data, Default 10
|
||||
shuffle(bool): shuffle dataset. Default: True
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
dataset_path = config["dataset_path"]
|
||||
lr_scale = config["scale"]
|
||||
lr_type = config.get("lr_type", "bicubic")
|
||||
batch_size = config.get("batch_size", 1)
|
||||
patch_size = config.get("patch_size", -1)
|
||||
epoch_size = config.get("epoch_size", None)
|
||||
|
||||
# for multi lr scale, such as [2,3,4]
|
||||
if isinstance(lr_scale, int):
|
||||
multi_lr_scale = [lr_scale]
|
||||
else:
|
||||
multi_lr_scale = lr_scale
|
||||
|
||||
# get HR_PATH/*.png
|
||||
dir_hr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_HR")
|
||||
hr_pattern = os.path.join(dir_hr, "*.png")
|
||||
|
||||
# get LR_PATH/X2/*x2.png, LR_PATH/X3/*x3.png, LR_PATH/X4/*x4.png
|
||||
column_names = []
|
||||
lrs_pattern = []
|
||||
for lr_scale in multi_lr_scale:
|
||||
dir_lr = os.path.join(dataset_path, f"DIV2K_{dataset_type}_LR_{lr_type}", f"X{lr_scale}")
|
||||
lr_pattern = os.path.join(dir_lr, f"*x{lr_scale}.png")
|
||||
lrs_pattern.append(lr_pattern)
|
||||
column_names.append(f"lrx{lr_scale}")
|
||||
column_names.append("hr") # ["lrx2","lrx3","lrx4",..., "hr"]
|
||||
|
||||
# make dataset
|
||||
dataset = LrHrImages(lr_pattern=lrs_pattern, hr_pattern=hr_pattern)
|
||||
|
||||
# make mindspore dataset
|
||||
device_num, rank_id = get_rank_info()
|
||||
if device_num == 1 or device_num is None:
|
||||
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
shuffle=shuffle and dataset_type == "train")
|
||||
elif dataset_type == "train":
|
||||
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
shuffle=shuffle and dataset_type == "train",
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
else:
|
||||
sampler = ds.DistributedSampler(num_shards=device_num, shard_id=rank_id, shuffle=False, offset=0)
|
||||
generator_dataset = ds.GeneratorDataset(dataset, column_names=column_names,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
sampler=sampler)
|
||||
|
||||
# define map operations
|
||||
if dataset_type == "train":
|
||||
transform_img = [
|
||||
RandomPatchCutter(patch_size, multi_lr_scale + [1]),
|
||||
flip_and_rotate,
|
||||
hwc2chw,
|
||||
uint8_to_float32,
|
||||
]
|
||||
elif patch_size > 0:
|
||||
transform_img = [
|
||||
CentrePatchCutter(patch_size, multi_lr_scale + [1]),
|
||||
hwc2chw,
|
||||
uint8_to_float32,
|
||||
]
|
||||
else:
|
||||
transform_img = [
|
||||
hwc2chw,
|
||||
uint8_to_float32,
|
||||
]
|
||||
|
||||
# pre-process hr lr
|
||||
generator_dataset = generator_dataset.map(input_columns=column_names,
|
||||
output_columns=column_names,
|
||||
column_order=column_names,
|
||||
operations=transform_img)
|
||||
|
||||
# apply batch operations
|
||||
generator_dataset = generator_dataset.batch(batch_size, drop_remainder=False)
|
||||
|
||||
# apply repeat operations
|
||||
if dataset_type == "train" and epoch_size is not None and epoch_size != 1:
|
||||
generator_dataset = generator_dataset.repeat(epoch_size)
|
||||
|
||||
return generator_dataset
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""EDSR"""
|
||||
import numpy as np
|
||||
from mindspore import Parameter
|
||||
from mindspore import nn, ops
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
|
||||
class RgbNormal(nn.Cell):
|
||||
"""
|
||||
"MeanShift" in EDSR paper pytorch-code:
|
||||
https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/common.py
|
||||
|
||||
it is not unreasonable in the case below
|
||||
if std != 1 and sign = -1: y = x * rgb_std - rgb_range * rgb_mean
|
||||
if std != 1 and sign = 1: y = x * rgb_std + rgb_range * rgb_mean
|
||||
they are not inverse operation for each other!
|
||||
|
||||
so use "RgbNormal" instead, it runs as below:
|
||||
if inverse = False: y = (x / rgb_range - mean) / std
|
||||
if inverse = True : x = (y * std + mean) * rgb_range
|
||||
"""
|
||||
def __init__(self, rgb_range, rgb_mean, rgb_std, inverse=False):
|
||||
super(RgbNormal, self).__init__()
|
||||
self.rgb_range = rgb_range
|
||||
self.rgb_mean = rgb_mean
|
||||
self.rgb_std = rgb_std
|
||||
self.inverse = inverse
|
||||
std = np.array(self.rgb_std, dtype=np.float32)
|
||||
mean = np.array(self.rgb_mean, dtype=np.float32)
|
||||
if not inverse:
|
||||
# y: (x / rgb_range - mean) / std <=> x * (1.0 / rgb_range / std) + (-mean) / std
|
||||
weight = (1.0 / self.rgb_range / std).reshape((1, -1, 1, 1))
|
||||
bias = (-mean / std).reshape((1, -1, 1, 1))
|
||||
else:
|
||||
# x: (y * std + mean) * rgb_range <=> y * (std * rgb_range) + mean * rgb_range
|
||||
weight = (self.rgb_range * std).reshape((1, -1, 1, 1))
|
||||
bias = (mean * rgb_range).reshape((1, -1, 1, 1))
|
||||
self.weight = Parameter(weight, requires_grad=False)
|
||||
self.bias = Parameter(bias, requires_grad=False)
|
||||
|
||||
def construct(self, x):
|
||||
return x * self.weight + self.bias
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'rgb_range={}, rgb_mean={}, rgb_std={}, inverse = {}' \
|
||||
.format(
|
||||
self.rgb_range,
|
||||
self.rgb_mean,
|
||||
self.rgb_std,
|
||||
self.inverse,
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
def make_conv2d(in_channels, out_channels, kernel_size, has_bias=True):
|
||||
return nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
pad_mode="same", has_bias=has_bias, weight_init=TruncatedNormal(0.02))
|
||||
|
||||
|
||||
class ResBlock(nn.Cell):
|
||||
"""
|
||||
Resnet Block
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=1, has_bias=True, res_scale=1):
|
||||
super(ResBlock, self).__init__()
|
||||
self.conv1 = make_conv2d(in_channels, in_channels, kernel_size, has_bias)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = make_conv2d(in_channels, out_channels, kernel_size, has_bias)
|
||||
self.res_scale = res_scale
|
||||
|
||||
def construct(self, x):
|
||||
res = self.conv1(x)
|
||||
res = self.relu(res)
|
||||
res = self.conv2(res)
|
||||
res = res * self.res_scale
|
||||
x = x + res
|
||||
return x
|
||||
|
||||
|
||||
class PixelShuffle(nn.Cell):
|
||||
"""
|
||||
PixelShuffle using ops.DepthToSpace
|
||||
"""
|
||||
def __init__(self, upscale_factor):
|
||||
super(PixelShuffle, self).__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
self.upper = ops.DepthToSpace(self.upscale_factor)
|
||||
|
||||
def construct(self, x):
|
||||
return self.upper(x)
|
||||
|
||||
def extend_repr(self):
|
||||
return 'upscale_factor={}'.format(self.upscale_factor)
|
||||
|
||||
|
||||
def UpsamplerBlockList(upscale_factor, n_feats, has_bias=True):
|
||||
"""
|
||||
make Upsampler Block List
|
||||
"""
|
||||
if upscale_factor == 1:
|
||||
return []
|
||||
allow_sub_upscale_factor = [2, 3, None]
|
||||
for sub in allow_sub_upscale_factor:
|
||||
if sub is None:
|
||||
raise NotImplementedError(
|
||||
f"Only support \"scales\" that can be divisibled by {allow_sub_upscale_factor[:-1]}")
|
||||
if upscale_factor % sub == 0:
|
||||
break
|
||||
sub_block_list = [
|
||||
make_conv2d(n_feats, sub*sub*n_feats, 3, has_bias),
|
||||
PixelShuffle(sub),
|
||||
]
|
||||
return sub_block_list + UpsamplerBlockList(upscale_factor // sub, n_feats, has_bias)
|
||||
|
||||
|
||||
class Upsampler(nn.Cell):
|
||||
|
||||
def __init__(self, scale, n_feats, has_bias=True):
|
||||
super(Upsampler, self).__init__()
|
||||
up = UpsamplerBlockList(scale, n_feats, has_bias)
|
||||
self.up = nn.SequentialCell(*up)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.up(x)
|
||||
return x
|
||||
|
||||
|
||||
class EDSR(nn.Cell):
|
||||
"""
|
||||
EDSR network
|
||||
"""
|
||||
def __init__(self, scale, n_feats, kernel_size, n_resblocks,
|
||||
n_colors=3,
|
||||
res_scale=0.1,
|
||||
rgb_range=255,
|
||||
rgb_mean=(0.0, 0.0, 0.0),
|
||||
rgb_std=(1.0, 1.0, 1.0)):
|
||||
super(EDSR, self).__init__()
|
||||
|
||||
self.norm = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=False)
|
||||
self.de_norm = RgbNormal(rgb_range, rgb_mean, rgb_std, inverse=True)
|
||||
|
||||
m_head = [make_conv2d(n_colors, n_feats, kernel_size)]
|
||||
|
||||
m_body = [
|
||||
ResBlock(n_feats, n_feats, kernel_size, res_scale=res_scale)
|
||||
for _ in range(n_resblocks)
|
||||
]
|
||||
m_body.append(make_conv2d(n_feats, n_feats, kernel_size))
|
||||
|
||||
m_tail = [
|
||||
Upsampler(scale, n_feats),
|
||||
make_conv2d(n_feats, n_colors, kernel_size)
|
||||
]
|
||||
|
||||
self.head = nn.SequentialCell(m_head)
|
||||
self.body = nn.SequentialCell(m_body)
|
||||
self.tail = nn.SequentialCell(m_tail)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
x = x + self.body(x)
|
||||
x = self.tail(x)
|
||||
x = self.de_norm(x)
|
||||
return x
|
||||
|
||||
def load_pre_trained_param_dict(self, new_param_dict, strict=True):
|
||||
"""
|
||||
load pre_trained param dict from edsr_x2
|
||||
"""
|
||||
own_param = self.parameters_dict()
|
||||
for name, new_param in new_param_dict.items():
|
||||
if len(name) >= 4 and name[:4] == "net.":
|
||||
name = name[4:]
|
||||
if name in own_param:
|
||||
if isinstance(new_param, Parameter):
|
||||
param = own_param[name]
|
||||
if tuple(param.data.shape) == tuple(new_param.data.shape):
|
||||
param.set_data(type(param.data)(new_param.data))
|
||||
elif name.find('tail') == -1:
|
||||
raise RuntimeError('While copying the parameter named {}, '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}.'
|
||||
.format(name, own_param[name].shape, new_param.shape))
|
||||
elif strict:
|
||||
if name.find('tail') == -1:
|
||||
raise KeyError('unexpected key "{}" in parameters_dict()'
|
||||
.format(name))
|
|
@ -0,0 +1,330 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Metric for evaluation."""
|
||||
import os
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from mindspore import nn, Tensor, ops
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops.operations.comm_ops import ReduceOp
|
||||
|
||||
try:
|
||||
from model_utils.device_adapter import get_rank_id, get_device_num
|
||||
except ImportError:
|
||||
get_rank_id = None
|
||||
get_device_num = None
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
class SelfEnsembleWrapperNumpy:
|
||||
"""
|
||||
SelfEnsembleWrapperNumpy using numpy
|
||||
"""
|
||||
|
||||
def __init__(self, net):
|
||||
super(SelfEnsembleWrapperNumpy, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def hflip(self, x):
|
||||
return x[:, :, :, ::-1]
|
||||
|
||||
def vflip(self, x):
|
||||
return x[:, :, ::-1, :]
|
||||
|
||||
def trnsps(self, x):
|
||||
return x.transpose(0, 1, 3, 2)
|
||||
|
||||
def aug_x8(self, x):
|
||||
"""
|
||||
do x8 augments for input image
|
||||
"""
|
||||
# hflip
|
||||
hx = self.hflip(x)
|
||||
# vflip
|
||||
vx = self.vflip(x)
|
||||
vhx = self.vflip(hx)
|
||||
# trnsps
|
||||
tx = self.trnsps(x)
|
||||
thx = self.trnsps(hx)
|
||||
tvx = self.trnsps(vx)
|
||||
tvhx = self.trnsps(vhx)
|
||||
return x, hx, vx, vhx, tx, thx, tvx, tvhx
|
||||
|
||||
def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
|
||||
"""
|
||||
undo x8 augments for input images
|
||||
"""
|
||||
# trnsps
|
||||
tvhx = self.trnsps(tvhx)
|
||||
tvx = self.trnsps(tvx)
|
||||
thx = self.trnsps(thx)
|
||||
tx = self.trnsps(tx)
|
||||
# vflip
|
||||
tvhx = self.vflip(tvhx)
|
||||
tvx = self.vflip(tvx)
|
||||
vhx = self.vflip(vhx)
|
||||
vx = self.vflip(vx)
|
||||
# hflip
|
||||
tvhx = self.hflip(tvhx)
|
||||
thx = self.hflip(thx)
|
||||
vhx = self.hflip(vhx)
|
||||
hx = self.hflip(hx)
|
||||
return x, hx, vx, vhx, tx, thx, tvx, tvhx
|
||||
|
||||
def to_numpy(self, *inputs):
|
||||
if inputs:
|
||||
return None
|
||||
if len(inputs) == 1:
|
||||
return inputs[0].asnumpy()
|
||||
return [x.asnumpy() for x in inputs]
|
||||
|
||||
def to_tensor(self, *inputs):
|
||||
if inputs:
|
||||
return None
|
||||
if len(inputs) == 1:
|
||||
return Tensor(inputs[0])
|
||||
return [Tensor(x) for x in inputs]
|
||||
|
||||
def set_train(self, mode=True):
|
||||
self.net.set_train(mode)
|
||||
return self
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.to_numpy(x)
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
|
||||
x0 = self.net(x0)
|
||||
x1 = self.net(x1)
|
||||
x2 = self.net(x2)
|
||||
x3 = self.net(x3)
|
||||
x4 = self.net(x4)
|
||||
x5 = self.net(x5)
|
||||
x6 = self.net(x6)
|
||||
x7 = self.net(x7)
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.to_numpy(x0, x1, x2, x3, x4, x5, x6, x7)
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
|
||||
return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
|
||||
|
||||
|
||||
class SelfEnsembleWrapper(nn.Cell):
|
||||
"""
|
||||
because of [::-1] operator error, use "SelfEnsembleWrapperNumpy" instead
|
||||
"""
|
||||
def __init__(self, net):
|
||||
super(SelfEnsembleWrapper, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def hflip(self, x):
|
||||
raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
|
||||
|
||||
def vflip(self, x):
|
||||
raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
|
||||
|
||||
def trnsps(self, x):
|
||||
return x.transpose(0, 1, 3, 2)
|
||||
|
||||
def aug_x8(self, x):
|
||||
"""
|
||||
do x8 augments for input image
|
||||
"""
|
||||
# hflip
|
||||
hx = self.hflip(x)
|
||||
# vflip
|
||||
vx = self.vflip(x)
|
||||
vhx = self.vflip(hx)
|
||||
# trnsps
|
||||
tx = self.trnsps(x)
|
||||
thx = self.trnsps(hx)
|
||||
tvx = self.trnsps(vx)
|
||||
tvhx = self.trnsps(vhx)
|
||||
return x, hx, vx, vhx, tx, thx, tvx, tvhx
|
||||
|
||||
def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
|
||||
"""
|
||||
undo x8 augments for input images
|
||||
"""
|
||||
# trnsps
|
||||
tvhx = self.trnsps(tvhx)
|
||||
tvx = self.trnsps(tvx)
|
||||
thx = self.trnsps(thx)
|
||||
tx = self.trnsps(tx)
|
||||
# vflip
|
||||
tvhx = self.vflip(tvhx)
|
||||
tvx = self.vflip(tvx)
|
||||
vhx = self.vflip(vhx)
|
||||
vx = self.vflip(vx)
|
||||
# hflip
|
||||
tvhx = self.hflip(tvhx)
|
||||
thx = self.hflip(thx)
|
||||
vhx = self.hflip(vhx)
|
||||
hx = self.hflip(hx)
|
||||
return x, hx, vx, vhx, tx, thx, tvx, tvhx
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
do x8 aug, run network, undo x8 aug, calculate mean for 8 output
|
||||
"""
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
|
||||
x0 = self.net(x0)
|
||||
x1 = self.net(x1)
|
||||
x2 = self.net(x2)
|
||||
x3 = self.net(x3)
|
||||
x4 = self.net(x4)
|
||||
x5 = self.net(x5)
|
||||
x6 = self.net(x6)
|
||||
x7 = self.net(x7)
|
||||
x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
|
||||
return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
|
||||
|
||||
|
||||
class Quantizer(nn.Cell):
|
||||
"""
|
||||
clip by [0.0, 255.0], rount to int
|
||||
"""
|
||||
def __init__(self, _min=0.0, _max=255.0):
|
||||
super(Quantizer, self).__init__()
|
||||
self.round = ops.Round()
|
||||
self._min = _min
|
||||
self._max = _max
|
||||
|
||||
def construct(self, x):
|
||||
x = ops.clip_by_value(x, self._min, self._max)
|
||||
x = self.round(x)
|
||||
return x
|
||||
|
||||
|
||||
class TensorSyncer(nn.Cell):
|
||||
"""
|
||||
sync metric values from all mindspore-processes
|
||||
"""
|
||||
def __init__(self, _type="sum"):
|
||||
super(TensorSyncer, self).__init__()
|
||||
self._type = _type.lower()
|
||||
if self._type == "sum":
|
||||
self.ops = ops.AllReduce(ReduceOp.SUM)
|
||||
elif self._type == "gather":
|
||||
self.ops = ops.AllGather()
|
||||
else:
|
||||
raise ValueError(f"TensorSyncer._type == {self._type} is not support")
|
||||
|
||||
def construct(self, x):
|
||||
return self.ops(x)
|
||||
|
||||
|
||||
class _DistMetric(nn.Metric):
|
||||
"""
|
||||
gather data from all rank while eval(True)
|
||||
_type(str): choice from ["avg", "sum"].
|
||||
"""
|
||||
def __init__(self, _type):
|
||||
super(_DistMetric, self).__init__()
|
||||
self._type = _type.lower()
|
||||
self.all_reduce_sum = None
|
||||
if get_device_num is not None and get_device_num() > 1:
|
||||
self.all_reduce_sum = TensorSyncer(_type="sum")
|
||||
self.clear()
|
||||
|
||||
def _accumulate(self, value):
|
||||
if isinstance(value, (list, tuple)):
|
||||
self._acc_value += sum(value)
|
||||
self._count += len(value)
|
||||
else:
|
||||
self._acc_value += value
|
||||
self._count += 1
|
||||
|
||||
def clear(self):
|
||||
self._acc_value = 0.0
|
||||
self._count = 0
|
||||
|
||||
def eval(self, sync=True):
|
||||
"""
|
||||
sync: True, return metric value merged from all mindspore-processes
|
||||
sync: False, return metric value in this single mindspore-processes
|
||||
"""
|
||||
if self._count == 0:
|
||||
raise RuntimeError('self._count == 0')
|
||||
if self.sum is not None and sync:
|
||||
data = Tensor([self._acc_value, self._count], mstype.float32)
|
||||
data = self.all_reduce_sum(data)
|
||||
acc_value, count = self._convert_data(data).tolist()
|
||||
else:
|
||||
acc_value, count = self._acc_value, self._count
|
||||
if self._type == "avg":
|
||||
return acc_value / count
|
||||
if self._type == "sum":
|
||||
return acc_value
|
||||
raise RuntimeError(f"_DistMetric._type={self._type} is not support")
|
||||
|
||||
|
||||
class PSNR(_DistMetric):
|
||||
"""
|
||||
Define PSNR metric for SR network.
|
||||
"""
|
||||
def __init__(self, rgb_range, shave):
|
||||
super(PSNR, self).__init__(_type="avg")
|
||||
self.shave = shave
|
||||
self.rgb_range = rgb_range
|
||||
self.quantize = Quantizer(0.0, 255.0)
|
||||
|
||||
def update(self, *inputs):
|
||||
"""
|
||||
update psnr
|
||||
"""
|
||||
if len(inputs) != 2:
|
||||
raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
|
||||
sr, hr = inputs
|
||||
sr = self.quantize(sr)
|
||||
diff = (sr - hr) / self.rgb_range
|
||||
valid = diff
|
||||
if self.shave is not None and self.shave != 0:
|
||||
valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)]
|
||||
mse_list = (valid ** 2).mean(axis=(1, 2, 3))
|
||||
mse_list = self._convert_data(mse_list).tolist()
|
||||
psnr_list = [float(1e32) if mse == 0 else(- 10.0 * math.log10(mse)) for mse in mse_list]
|
||||
self._accumulate(psnr_list)
|
||||
|
||||
|
||||
class SaveSrHr(_DistMetric):
|
||||
"""
|
||||
help to save sr and hr
|
||||
"""
|
||||
def __init__(self, save_dir):
|
||||
super(SaveSrHr, self).__init__(_type="sum")
|
||||
self.save_dir = save_dir
|
||||
self.quantize = Quantizer(0.0, 255.0)
|
||||
self.rank_id = 0 if get_rank_id is None else get_rank_id()
|
||||
self.device_num = 1 if get_device_num is None else get_device_num()
|
||||
|
||||
def update(self, *inputs):
|
||||
"""
|
||||
update images to save
|
||||
"""
|
||||
if len(inputs) != 2:
|
||||
raise ValueError('SaveSrHr need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
|
||||
sr, hr = inputs
|
||||
sr = self.quantize(sr)
|
||||
sr = self._convert_data(sr).astype(np.uint8)
|
||||
hr = self._convert_data(hr).astype(np.uint8)
|
||||
for s, h in zip(sr.transpose(0, 2, 3, 1), hr.transpose(0, 2, 3, 1)):
|
||||
idx = self._count * self.device_num + self.rank_id
|
||||
sr_path = os.path.join(self.save_dir, f"{idx:0>4}_sr.png")
|
||||
Image.fromarray(s).save(sr_path)
|
||||
hr_path = os.path.join(self.save_dir, f"{idx:0>4}_hr.png")
|
||||
Image.fromarray(h).save(hr_path)
|
||||
self._accumulate(1)
|
|
@ -1,102 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""metrics"""
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def quantize(img, rgb_range):
|
||||
"""quantize image range to 0-255"""
|
||||
pixel_range = 255 / rgb_range
|
||||
img = np.multiply(img, pixel_range)
|
||||
img = np.clip(img, 0, 255)
|
||||
img = np.round(img) / pixel_range
|
||||
return img
|
||||
|
||||
|
||||
def calc_psnr(sr, hr, scale, rgb_range):
|
||||
"""calculate psnr"""
|
||||
hr = np.float32(hr)
|
||||
sr = np.float32(sr)
|
||||
diff = (sr - hr) / rgb_range
|
||||
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
|
||||
diff = np.multiply(diff, gray_coeffs).sum(1)
|
||||
if hr.size == 1:
|
||||
return 0
|
||||
if scale != 1:
|
||||
shave = scale
|
||||
else:
|
||||
shave = scale + 6
|
||||
if scale == 1:
|
||||
valid = diff
|
||||
else:
|
||||
valid = diff[..., shave:-shave, shave:-shave]
|
||||
mse = np.mean(pow(valid, 2))
|
||||
return -10 * math.log10(mse)
|
||||
|
||||
|
||||
def rgb2ycbcr(img, y_only=True):
|
||||
"""from rgb space to ycbcr space"""
|
||||
img.astype(np.float32)
|
||||
if y_only:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
return rlt
|
||||
|
||||
|
||||
def calc_ssim(img1, img2, scale):
|
||||
"""calculate ssim value"""
|
||||
def ssim(img1, img2):
|
||||
C1 = (0.01 * 255) ** 2
|
||||
C2 = (0.03 * 255) ** 2
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1 ** 2
|
||||
mu2_sq = mu2 ** 2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
border = 0
|
||||
if scale != 1:
|
||||
border = scale
|
||||
else:
|
||||
border = scale + 6
|
||||
img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
|
||||
img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1_y = img1_y[border:h - border, border:w - border]
|
||||
img2_y = img2_y[border:h - border, border:w - border]
|
||||
if img1_y.ndim == 2:
|
||||
return ssim(img1_y, img2_y)
|
||||
if img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for _ in range(3):
|
||||
ssims.append(ssim(img1, img2))
|
||||
return np.array(ssims).mean()
|
||||
if img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError('Wrong input image dimensions.')
|
|
@ -1,65 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""edsr_model"""
|
||||
import mindspore.nn as nn
|
||||
from src import common
|
||||
|
||||
|
||||
class EDSR(nn.Cell):
|
||||
"""EDSR"""
|
||||
def __init__(self, args, conv=common.default_conv):
|
||||
super(EDSR, self).__init__()
|
||||
|
||||
n_resblocks = args.n_resblocks
|
||||
n_feats = args.n_feats
|
||||
kernel_size = 3
|
||||
scale = args.scale[0]
|
||||
act = nn.ReLU()
|
||||
|
||||
self.sub_mean = common.MeanShift(args.rgb_range)
|
||||
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
|
||||
|
||||
# define head module
|
||||
m_head = [conv(args.n_colors, n_feats, kernel_size)]
|
||||
|
||||
# define body module
|
||||
m_body = [
|
||||
common.ResBlock(
|
||||
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
|
||||
) for _ in range(n_resblocks)
|
||||
]
|
||||
m_body.append(conv(n_feats, n_feats, kernel_size))
|
||||
|
||||
# define tail module
|
||||
m_tail = []
|
||||
m_tail += common.Upsampler(conv, scale, n_feats)
|
||||
m_tail.append(conv(n_feats, args.n_colors, kernel_size))
|
||||
|
||||
self.head = nn.SequentialCell(m_head)
|
||||
self.body = nn.SequentialCell(m_body)
|
||||
self.tail = nn.SequentialCell(m_tail)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.sub_mean(x)
|
||||
x = self.head(x)
|
||||
|
||||
res = self.body(x)
|
||||
res += x
|
||||
|
||||
x = self.tail(res)
|
||||
x = self.add_mean(x)
|
||||
|
||||
return x
|
|
@ -12,76 +12,179 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""edsr train wrapper"""
|
||||
"""
|
||||
#################utils for train.py and eval.py########################
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
class Trainer():
|
||||
"""Trainer"""
|
||||
def __init__(self, args, loader, my_model):
|
||||
self.args = args
|
||||
self.scale = args.scale
|
||||
self.trainloader = loader
|
||||
self.model = my_model
|
||||
self.model.set_train()
|
||||
self.criterion = nn.L1Loss()
|
||||
self.loss_history = []
|
||||
self.begin_time = time.time()
|
||||
self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
|
||||
self.loss_net = nn.WithLossCell(self.model, self.criterion)
|
||||
self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer)
|
||||
def train(self, epoch):
|
||||
"""Trainer"""
|
||||
losses = 0
|
||||
batch_idx = 0
|
||||
for batch_idx, imgs in enumerate(self.trainloader):
|
||||
lr = imgs["LR"]
|
||||
hr = imgs["HR"]
|
||||
lr = Tensor(lr, mstype.float32)
|
||||
hr = Tensor(hr, mstype.float32)
|
||||
t1 = time.time()
|
||||
loss = self.net(lr, hr)
|
||||
t2 = time.time()
|
||||
losses += loss.asnumpy()
|
||||
print('Epoch: %g, Step: %g , loss: %f, time: %f s ' % \
|
||||
(epoch, batch_idx, loss.asnumpy(), t2 - t1), end='\n', flush=True)
|
||||
print("the epoch loss is", losses / (batch_idx + 1), flush=True)
|
||||
self.loss_history.append(losses / (batch_idx + 1))
|
||||
print(self.loss_history)
|
||||
t = time.time() - self.begin_time
|
||||
t = int(t)
|
||||
print(", running time: %gh%g'%g''"%(t//3600, (t-t//3600*3600)//60, t%60), flush=True)
|
||||
os.makedirs(self.args.save, exist_ok=True)
|
||||
if self.args.rank == 0 and (epoch+1)%10 == 0:
|
||||
save_checkpoint(self.net, self.args.save + "model_" + str(self.epoch) + '.ckpt')
|
||||
def update_learning_rate(self, epoch):
|
||||
"""Update learning rates for all the networks; called at the end of every epoch.
|
||||
:param epoch: current epoch
|
||||
:type epoch: int
|
||||
:param lr: learning rate of cyclegan
|
||||
:type lr: float
|
||||
:param niter: number of epochs with the initial learning rate
|
||||
:type niter: int
|
||||
:param niter_decay: number of epochs to linearly decay learning rate to zero
|
||||
:type niter_decay: int
|
||||
"""
|
||||
self.epoch = epoch
|
||||
print("*********** epoch: {} **********".format(epoch))
|
||||
lr = self.args.lr / (2 ** ((epoch+1)//200))
|
||||
self.adjust_lr('model', self.optimizer, lr)
|
||||
print("*********************************")
|
||||
def adjust_lr(self, name, optimizer, lr):
|
||||
"""Adjust learning rate for the corresponding model.
|
||||
:param name: name of model
|
||||
:type name: str
|
||||
:param optimizer: the optimizer of the corresponding model
|
||||
:type optimizer: torch.optim
|
||||
:param lr: learning rate to be adjusted
|
||||
:type lr: float
|
||||
"""
|
||||
lr_param = optimizer.get_lr()
|
||||
lr_param.assign_value(Tensor(lr, mstype.float32))
|
||||
print('==> ' + name + ' learning rate: ', lr_param.asnumpy())
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
|
||||
from .dataset import create_dataset_DIV2K
|
||||
from .edsr import EDSR
|
||||
|
||||
|
||||
def init_env(cfg):
|
||||
"""
|
||||
init env for mindspore
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
||||
device_num = get_device_num()
|
||||
if cfg.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
elif cfg.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
elif cfg.device_target == "CPU":
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
|
||||
def init_dataset(cfg, dataset_type="train"):
|
||||
"""
|
||||
init DIV2K dataset
|
||||
"""
|
||||
ds_cfg = {
|
||||
"dataset_path": cfg.data_path,
|
||||
"scale": cfg.scale,
|
||||
"lr_type": cfg.lr_type,
|
||||
"batch_size": cfg.batch_size,
|
||||
"patch_size": cfg.patch_size,
|
||||
}
|
||||
if cfg.dataset_name == "DIV2K":
|
||||
dataset = create_dataset_DIV2K(config=ds_cfg,
|
||||
dataset_type=dataset_type,
|
||||
num_parallel_workers=10,
|
||||
shuffle=dataset_type == "Train")
|
||||
else:
|
||||
raise ValueError("Unsupported dataset.")
|
||||
return dataset
|
||||
|
||||
|
||||
def init_net(cfg):
|
||||
"""
|
||||
init edsr network
|
||||
"""
|
||||
net = EDSR(scale=cfg.scale,
|
||||
n_feats=cfg.n_feats,
|
||||
kernel_size=cfg.kernel_size,
|
||||
n_resblocks=cfg.n_resblocks,
|
||||
n_colors=cfg.n_colors,
|
||||
res_scale=cfg.res_scale,
|
||||
rgb_range=cfg.rgb_range,
|
||||
rgb_mean=cfg.rgb_mean,
|
||||
rgb_std=cfg.rgb_std,)
|
||||
if cfg.pre_trained:
|
||||
pre_trained_path = os.path.join(cfg.output_path, cfg.pre_trained)
|
||||
if len(cfg.pre_trained) >= 5 and cfg.pre_trained[:5] == "s3://":
|
||||
pre_trained_path = cfg.pre_trained
|
||||
import moxing as mox
|
||||
mox.file.shift("os", "mox") # then system can read file from s3://
|
||||
elif os.path.isfile(cfg.pre_trained):
|
||||
pre_trained_path = cfg.pre_trained
|
||||
elif os.path.isfile(pre_trained_path):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"pre_trained error: {cfg.pre_trained}")
|
||||
print(f"loading pre_trained = {pre_trained_path}", flush=True)
|
||||
param_dict = load_checkpoint(pre_trained_path)
|
||||
net.load_pre_trained_param_dict(param_dict, strict=False)
|
||||
return net
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
zip_name = os.path.basename(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
data_print = int(data_num / 4) if data_num > 4 else 1
|
||||
len_data_num = len(str(data_num))
|
||||
for i, _file in enumerate(fz.namelist()):
|
||||
if i % data_print == 0:
|
||||
print("[{1:>{0}}/{2:>{0}}] {3:>2}% const time: {4:0>2}:{5:0>2} unzipping {6}".format(
|
||||
len_data_num,
|
||||
i,
|
||||
data_num,
|
||||
int(i / data_num * 100),
|
||||
int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60),
|
||||
zip_name,
|
||||
flush=True))
|
||||
fz.extract(_file, save_dir)
|
||||
print(" finish const time: {:0>2}:{:0>2} unzipping {}".format(
|
||||
int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60),
|
||||
zip_name,
|
||||
flush=True))
|
||||
else:
|
||||
print("{} is not zip.".format(zip_name), flush=True)
|
||||
|
||||
if config.enable_modelarts and config.need_unzip_in_modelarts:
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
for ufile in config.need_unzip_files:
|
||||
zip_file = os.path.join(config.data_path, ufile)
|
||||
save_dir = os.path.dirname(zip_file)
|
||||
unzip(zip_file, save_dir)
|
||||
print("===Finish extract data synchronization===", flush=True)
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data.".format(get_device_id()), flush=True)
|
||||
|
||||
config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)
|
||||
|
||||
|
||||
def do_eval(eval_network, ds_val, metrics, cur_epoch=None):
|
||||
"""
|
||||
do eval for psnr and save hr, sr
|
||||
"""
|
||||
eval_network.set_train(False)
|
||||
total_step = ds_val.get_dataset_size()
|
||||
setw = len(str(total_step))
|
||||
begin = time.time()
|
||||
step_begin = time.time()
|
||||
rank_id = get_rank_id()
|
||||
for i, (lr, hr) in enumerate(ds_val):
|
||||
sr = eval_network(lr)
|
||||
_ = [m.update(sr, hr) for m in metrics.values()]
|
||||
result = {k: m.eval(sync=False) for k, m in metrics.items()}
|
||||
result["time"] = time.time() - step_begin
|
||||
step_begin = time.time()
|
||||
print(f"[{i+1:>{setw}}/{total_step:>{setw}}] rank = {rank_id} result = {result}", flush=True)
|
||||
result = {k: m.eval(sync=True) for k, m in metrics.items()}
|
||||
result["time"] = time.time() - begin
|
||||
if cur_epoch is not None:
|
||||
result["epoch"] = cur_epoch
|
||||
if rank_id == 0:
|
||||
print(f"evaluation result = {result}", flush=True)
|
||||
eval_network.set_train(True)
|
||||
return result
|
||||
|
|
|
@ -12,62 +12,138 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""edsr train script"""
|
||||
import os
|
||||
from mindspore import context
|
||||
from mindspore import dataset as ds
|
||||
"""
|
||||
#################train EDSR example on DIV2K########################
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
|
||||
from mindspore.train.model import Model
|
||||
from src.args import args
|
||||
from src.data.div2k import DIV2K
|
||||
from src.model import EDSR
|
||||
from mindspore.common import set_seed
|
||||
|
||||
def train_net():
|
||||
"""train edsr"""
|
||||
set_seed(1)
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
rank_id = int(os.getenv('RANK_ID', '0'))
|
||||
device_num = int(os.getenv('RANK_SIZE', '1'))
|
||||
# if distribute:
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=device_num, gradients_mean=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
|
||||
train_dataset.set_scale(args.task_id)
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
|
||||
shard_id=rank_id, shuffle=True)
|
||||
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
|
||||
net_m = EDSR(args)
|
||||
print("Init net successfully")
|
||||
if args.ckpt_path:
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
print("Load net weight successfully")
|
||||
step_size = train_de_dataset.get_dataset_size()
|
||||
lr = []
|
||||
for i in range(0, args.epochs):
|
||||
cur_lr = args.lr / (2 ** ((i + 1)//200))
|
||||
lr.extend([cur_lr] * step_size)
|
||||
opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
|
||||
loss = nn.L1Loss()
|
||||
model = Model(net_m, loss_fn=loss, optimizer=opt)
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
from src.metric import PSNR
|
||||
from src.utils import init_env, init_dataset, init_net, modelarts_pre_process, do_eval
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_rank_id, get_device_num
|
||||
|
||||
|
||||
set_seed(2021)
|
||||
|
||||
|
||||
def lr_steps_edsr(lr, milestones, gamma, epoch_size, steps_per_epoch, last_epoch=None):
|
||||
lr_each_step = []
|
||||
step_begin_epoch = [0] + milestones[:-1]
|
||||
step_end_epoch = milestones[1:] + [epoch_size]
|
||||
for begin, end in zip(step_begin_epoch, step_end_epoch):
|
||||
lr_each_step += [lr] * (end - begin) * steps_per_epoch
|
||||
lr *= gamma
|
||||
if last_epoch is not None:
|
||||
lr_each_step = lr_each_step[last_epoch * steps_per_epoch:]
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def init_opt(cfg, net):
|
||||
"""
|
||||
init opt to train edsr
|
||||
"""
|
||||
lr = lr_steps_edsr(lr=cfg.learning_rate, milestones=cfg.milestones, gamma=cfg.gamma,
|
||||
epoch_size=cfg.epoch_size, steps_per_epoch=cfg.steps_per_epoch, last_epoch=None)
|
||||
loss_scale = 1.0 if cfg.amp_level == "O0" else cfg.loss_scale
|
||||
if cfg.opt_type == "Adam":
|
||||
opt = nn.Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()),
|
||||
learning_rate=Tensor(lr),
|
||||
weight_decay=cfg.weight_decay,
|
||||
loss_scale=loss_scale)
|
||||
elif cfg.opt_type == "SGD":
|
||||
opt = nn.SGD(params=filter(lambda x: x.requires_grad, net.get_parameters()),
|
||||
learning_rate=Tensor(lr),
|
||||
weight_decay=cfg.weight_decay,
|
||||
momentum=cfg.momentum,
|
||||
dampening=cfg.dampening if hasattr(cfg, "dampening") else 0.0,
|
||||
nesterov=cfg.nesterov if hasattr(cfg, "nesterov") else False,
|
||||
loss_scale=loss_scale)
|
||||
else:
|
||||
raise ValueError("Unsupported optimizer.")
|
||||
return opt
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
eval callback
|
||||
"""
|
||||
def __init__(self, eval_network, ds_val, eval_epoch_frq, epoch_size, metrics, result_evaluation=None):
|
||||
self.eval_network = eval_network
|
||||
self.ds_val = ds_val
|
||||
self.eval_epoch_frq = eval_epoch_frq
|
||||
self.epoch_size = epoch_size
|
||||
self.result_evaluation = result_evaluation
|
||||
self.metrics = metrics
|
||||
self.best_result = None
|
||||
self.eval_network.set_train(False)
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""
|
||||
do eval in epoch end
|
||||
"""
|
||||
cb_param = run_context.original_args()
|
||||
cur_epoch = cb_param.cur_epoch_num
|
||||
if cur_epoch % self.eval_epoch_frq == 0 or cur_epoch == self.epoch_size:
|
||||
result = do_eval(self.eval_network, self.ds_val, self.metrics, cur_epoch=cur_epoch)
|
||||
if self.best_result is None or self.best_result["psnr"] < result["psnr"]:
|
||||
self.best_result = result
|
||||
if get_rank_id() == 0:
|
||||
print(f"best evaluation result = {self.best_result}", flush=True)
|
||||
if isinstance(self.result_evaluation, dict):
|
||||
for k, v in result.items():
|
||||
r_list = self.result_evaluation.get(k)
|
||||
if r_list is None:
|
||||
r_list = []
|
||||
self.result_evaluation[k] = r_list
|
||||
r_list.append(v)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
"""
|
||||
run train
|
||||
"""
|
||||
print(config)
|
||||
cfg = config
|
||||
|
||||
init_env(cfg)
|
||||
|
||||
ds_train = init_dataset(cfg, "train")
|
||||
ds_val = init_dataset(cfg, "valid")
|
||||
|
||||
net = init_net(cfg)
|
||||
cfg.steps_per_epoch = ds_train.get_dataset_size()
|
||||
opt = init_opt(cfg, net)
|
||||
|
||||
loss = nn.L1Loss(reduction='mean')
|
||||
|
||||
eval_net = net
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, amp_level=cfg.amp_level)
|
||||
|
||||
metrics = {
|
||||
"psnr": PSNR(rgb_range=cfg.rgb_range, shave=True),
|
||||
}
|
||||
eval_cb = EvalCallBack(eval_net, ds_val, cfg.eval_epoch_frq, cfg.epoch_size, metrics=metrics)
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.steps_per_epoch * cfg.save_epoch_frq,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
time_cb = TimeMonitor()
|
||||
ckpoint_cb = ModelCheckpoint(prefix=f"EDSR_x{cfg.scale}_" + cfg.dataset_name, directory=cfg.ckpt_save_dir,
|
||||
config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
|
||||
keep_checkpoint_max=args.ckpt_save_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="edsr", directory=args.ckpt_save_path, config=config_ck)
|
||||
if device_id == 0:
|
||||
cb += [ckpt_cb]
|
||||
model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
cbs = [time_cb, ckpoint_cb, loss_cb, eval_cb]
|
||||
if get_device_num() > 1 and get_rank_id() != 0:
|
||||
cbs = [time_cb, loss_cb, eval_cb]
|
||||
|
||||
model.train(cfg.epoch_size, ds_train, dataset_sink_mode=cfg.dataset_sink_mode, callbacks=cbs)
|
||||
print("train success", flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_net()
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
Loading…
Reference in New Issue