!17158 Add the image denoising network BRDNet to the master branch
Merge pull request !17158 from 谭华林/master
This commit is contained in:
commit
8720c66af8

@ -0,0 +1,532 @@
# BRDNet

<!-- TOC -->

- [BRDNet](#brdnet)
    - [BRDNet Introduction](#brdnet-introduction)
    - [Model Architecture](#model-architecture)
    - [Dataset](#dataset)
    - [Environment Requirements](#environment-requirements)
    - [Quick Start](#quick-start)
    - [Script Description](#script-description)
        - [Script and Sample Code](#script-and-sample-code)
        - [Script Parameters](#script-parameters)
        - [Training Process](#training-process)
            - [Training](#training)
        - [Evaluation Process](#evaluation-process)
            - [Evaluation](#evaluation)
        - [Ascend 310 Inference](#ascend-310-inference)
    - [Model Description](#model-description)
        - [Performance](#performance)
            - [Evaluation Performance](#evaluation-performance)
    - [Description of Random Situation](#description-of-random-situation)
    - [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

## BRDNet Introduction
BRDNet was published in 2020 in the journal Neural Networks and was rated one of the journal's most downloaded papers of 2019/2020. The paper proposes a way to handle unevenly distributed data under limited hardware resources, and it is the first to use a dual-branch network that extracts complementary features for image denoising; it removes both synthetic and real noise effectively.

[Paper](https://www.sciencedirect.com/science/article/pii/S0893608019302394): Tian C, Xu Y, Zuo W. Image denoising using deep CNN with batch renormalization[J]. Neural Networks, 2020, 121: 461-473.

## Model Architecture

BRDNet consists of an upper and a lower branch. The upper branch uses only residual learning and BRN (batch renormalization); the lower branch uses BRN, residual learning, and dilated convolutions.

The upper branch contains two types of layers: Conv+BRN+ReLU and Conv. It is 17 layers deep: the first 16 layers are Conv+BRN+ReLU and the last layer is Conv. The feature maps have 64 channels and the kernel size is 3.

The lower branch also has 17 layers, but layers 1, 9, and 16 are Conv+BRN+ReLU, layers 2-8 and 10-15 are dilated convolutions, and the last layer is Conv. The kernel size is 3 and the number of channels is 64.

The outputs of the two branches are combined by `Concat` and passed through a final Conv to predict the noise; subtracting this noise from the original input yields the clean image. In total BRDNet has 18 layers, which is relatively shallow and avoids vanishing and exploding gradients. A minimal sketch of this structure follows.
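The sketch below, written against the MindSpore 1.x API used elsewhere in this PR, shows how such a two-branch cell could be assembled. It is an illustration only, not the `src/models.py` shipped with the code; in particular, plain `nn.BatchNorm2d` stands in for the batch renormalization (BRN) layer used by the real model, and details such as where the residual subtractions are applied may differ.

```python
import mindspore.nn as nn
from mindspore.ops import operations as P


class BRDNetSketch(nn.Cell):
    """Simplified two-branch BRDNet as described above (not the shipped model)."""

    def __init__(self, channels=3, features=64):
        super(BRDNetSketch, self).__init__()
        # Upper branch: 16 x (Conv + BRN + ReLU) followed by one Conv.
        upper, in_ch = [], channels
        for _ in range(16):
            upper += [nn.Conv2d(in_ch, features, 3, pad_mode='same'),
                      nn.BatchNorm2d(features), nn.ReLU()]
            in_ch = features
        upper.append(nn.Conv2d(features, channels, 3, pad_mode='same'))
        self.upper = nn.SequentialCell(upper)

        # Lower branch: layers 1, 9 and 16 are Conv + BRN + ReLU,
        # layers 2-8 and 10-15 are dilated convolutions, the last layer is Conv.
        lower, in_ch = [], channels
        for i in range(1, 17):
            if i in (1, 9, 16):
                lower += [nn.Conv2d(in_ch, features, 3, pad_mode='same'),
                          nn.BatchNorm2d(features), nn.ReLU()]
            else:
                lower += [nn.Conv2d(in_ch, features, 3, pad_mode='same', dilation=2),
                          nn.ReLU()]
            in_ch = features
        lower.append(nn.Conv2d(features, channels, 3, pad_mode='same'))
        self.lower = nn.SequentialCell(lower)

        self.concat = P.Concat(axis=1)
        self.fuse = nn.Conv2d(2 * channels, channels, 3, pad_mode='same')

    def construct(self, x):
        noise = self.fuse(self.concat((self.upper(x), self.lower(x))))
        return x - noise  # residual learning: subtract the predicted noise
```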
## Dataset
Training dataset: [color noisy](https://pan.baidu.com/s/1cx3ymsWLIT-YIiJRBza24Q)

For Gaussian noise removal, the training data consists of 3,859 images from the Waterloo Exploration Database, preprocessed into 50x50 patches, 1,166,393 patches in total.
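The script that produced these patches is not part of this PR; the snippet below is only a hypothetical illustration of how such patches can be cut. The patch size (50) and stride (40) are assumptions read off the dataset folder name `waterloo5050step40colorimage`.

```python
import glob
import os

import numpy as np
import PIL.Image as Image


def make_patches(src_dir, dst_dir, patch=50, stride=40):
    """Cut every image in src_dir into patch x patch crops taken every `stride` pixels."""
    os.makedirs(dst_dir, exist_ok=True)
    count = 0
    for path in glob.glob(os.path.join(src_dir, '*')):
        img = np.array(Image.open(path))  # HWC, uint8
        h, w = img.shape[:2]
        for top in range(0, h - patch + 1, stride):
            for left in range(0, w - patch + 1, stride):
                crop = img[top:top + patch, left:left + patch]
                Image.fromarray(crop).save(os.path.join(dst_dir, '%07d.png' % count))
                count += 1
    return count
```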
Test datasets: CBSD68, Kodak24, and McMaster
## Environment Requirements

- Hardware (Ascend/ModelArts)
    - Prepare the hardware environment with an Ascend processor or ModelArts.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.html)

## Quick Start

After installing MindSpore via the official website, you can start training and evaluation as follows:
```shell
# Run single-device training from the python command line. Note that train_data must end with "/".
python train.py \
    --train_data=xxx/dataset/waterloo5050step40colorimage/ \
    --sigma=15 \
    --channel=3 \
    --batch_size=32 \
    --lr=0.001 \
    --use_modelarts=0 \
    --output_path=./output/ \
    --is_distributed=0 \
    --epoch=50 > log.txt 2>&1 &

# Start single-device training through sh. (The path format of train_data and similar arguments does not matter;
# they are converted internally to absolute paths ending with "/".)
sh ./scripts/run_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr]

# Ascend multi-device training.
sh run_distribute_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]

# Run evaluation from the python command line. Note that test_dir must end with "/";
# pretrain_path is the directory containing the ckpt file; for ModelArts compatibility it is split into "path" and "file name".
python eval.py \
    --test_dir=xxx/Test/Kodak24/ \
    --sigma=15 \
    --channel=3 \
    --pretrain_path=xxx/ \
    --ckpt_name=channel_3_sigma_15_rank_0-50_227800.ckpt \
    --use_modelarts=0 \
    --output_path=./output/ \
    --is_distributed=0 > log.txt 2>&1 &

# Start evaluation through sh. (The path format of test_dir and similar arguments does not matter;
# they are converted internally to absolute paths ending with "/".)
sh run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name]
```
For Ascend distributed training, generate a [RANK_TABLE_FILE](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) first.
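Inside train.py (not shown in this diff) the distributed run is controlled by the is_distributed, rank and group_size parameters listed later in this README. A typical MindSpore data-parallel setup looks like the sketch below; it is illustrative and may differ in detail from the actual script.

```python
# Illustrative Ascend data-parallel initialization; the real logic lives in train.py.
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()                         # reads RANK_TABLE_FILE / DEVICE_ID from the environment
rank = get_rank()              # local rank, 0..group_size-1
group_size = get_group_size()  # number of devices
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  device_num=group_size)
```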
## Script Description

### Script and Sample Code

```shell
├── model_zoo
    ├── README.md                       // description of all models
    ├── brdnet
        ├── README.md                   // description of brdnet
        ├── ascend310_infer             // Ascend 310 inference code (C++)
        │   ├──inc
        │   │   ├──utils.h              // utility header
        │   ├──src
        │   │   ├──main.cc              // 310 inference code
        │   │   ├──utils.cc             // utilities
        │   ├──build.sh                 // build script
        │   ├──CMakeLists.txt           // build configuration
        ├── scripts
        │   ├──run_distribute_train.sh  // 8-device Ascend training script
        │   ├──run_eval.sh              // evaluation launch script
        │   ├──run_train.sh             // training launch script
        │   ├──run_infer_310.sh         // 310 inference launch script
        ├── src
        │   ├──dataset.py               // dataset processing
        │   ├──distributed_sampler.py   // dataset sharding for 8-device parallel training
        │   ├──logger.py                // logging
        │   ├──models.py                // model definition
        ├── export.py                   // exports the checkpoint to MINDIR and other formats
        ├── train.py                    // training script
        ├── eval.py                     // evaluation script
        ├── cal_psnr.py                 // computes the final PSNR for 310 inference
        ├── preprocess.py               // adds noise to the test images for 310 inference
```
### Script Parameters

```shell
Major parameters in train.py:
--batch_size: batch size
--train_data: path of the training dataset (must end with "/")
--sigma: Gaussian noise level
--channel: image type for training (3: color; 1: grayscale)
--epoch: number of training epochs
--lr: initial learning rate
--save_every: checkpoint saving interval (save once every N epochs)
--pretrain: pre-trained checkpoint to resume training from
--use_modelarts: whether to use ModelArts (1 for True, 0 for False; when set to 1, moxing is used to copy data from OBS)
--train_url: (required by ModelArts, but unused in the code because the name is ambiguous)
--data_url: (required by ModelArts, but unused in the code because the name is ambiguous)
--output_path: output directory for logs and other files
--outer_path: output directory on OBS (only effective when running on ModelArts)
--device_target: target device (default "Ascend")
--is_distributed: whether to run on multiple devices
--rank: local rank in distributed training. Default: 0
--group_size: world size (number of devices). Default: 1
--is_save_on_master: whether to save results only on device 0
--ckpt_save_max: maximum number of checkpoint files to keep

Major parameters in eval.py:
--test_dir: path of the test dataset (must end with "/")
--sigma: Gaussian noise level
--channel: image type for evaluation (3: color; 1: grayscale)
--pretrain_path: directory of the checkpoint file
--ckpt_name: name of the checkpoint file
--use_modelarts: whether to use ModelArts (1 for True, 0 for False; when set to 1, moxing is used to copy data from OBS)
--train_url: (required by ModelArts, but unused in the code because the name is ambiguous)
--data_url: (required by ModelArts, but unused in the code because the name is ambiguous)
--output_path: output directory for logs and other files
--outer_path: output directory on OBS (only effective when running on ModelArts)
--device_target: target device (default "Ascend")

Major parameters in export.py:
--batch_size: batch size
--channel: image type (3: color; 1: grayscale)
--image_height: image height
--image_width: image width
--ckpt_file: path of the checkpoint file
--file_name: name of the exported output file
--file_format: output format, choices=["AIR", "ONNX", "MINDIR"]
--device_target: target device (default "Ascend")
--device_id: device id

Major parameters in preprocess.py:
--out_dir: directory to store the noisy images
--image_path: path of the test images (must end with "/")
--channel: image channel (3: color; 1: grayscale)
--sigma: noise level

Major parameters in cal_psnr.py:
--image_path: path of the test images (must end with "/")
--output_path: directory holding the inference results, default xx/scripts/result_Files/
--image_width: image width
--image_height: image height
--channel: image channel (3: color; 1: grayscale)
```
### Training Process

#### Training

- Running on Ascend

```shell
# Run single-device training from the python command line. Note that train_data must end with "/".
python train.py \
    --train_data=xxx/dataset/waterloo5050step40colorimage/ \
    --sigma=15 \
    --channel=3 \
    --batch_size=32 \
    --lr=0.001 \
    --use_modelarts=0 \
    --output_path=./output/ \
    --is_distributed=0 \
    --epoch=50 > log.txt 2>&1 &

# Start single-device training through sh. (The path format of train_data and similar arguments does not matter;
# they are converted internally to absolute paths ending with "/".)
sh ./scripts/run_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr]

# Both commands above run in the background; the log is written to log.txt, which you can inspect to follow training.

# Ascend multi-device training (for 2 or 4 devices, modify run_distribute_train.sh yourself; the default is 8 devices).
sh run_distribute_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]
```

Note: on the first run the script may stay on the screen below for quite a while, because the log is only printed after a whole epoch has finished; please be patient.

On a single device the first epoch takes roughly 20 to 30 minutes.
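As a sanity check on the logs that follow: with 1,166,393 training patches and batch_size=32, one epoch is 36,449 steps, which matches the `step: 36449` lines below; at roughly 31.2 ms per step that is about 19 minutes per epoch. With 8 devices each rank processes one eighth of the data, giving the 4,556 steps per epoch seen in the 8-device log.

```python
# Steps per epoch implied by the dataset size and batch size; these figures
# match the "step: 36449" (1 device) and "step: 4556" (8 devices) log lines below.
patches, batch = 1166393, 32
print(patches // batch)        # 36449 steps per epoch on 1 device
print(patches // (batch * 8))  # 4556 steps per epoch per rank on 8 devices
```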

```text
|
||||
2021-05-16 20:12:17,888:INFO:Args:
|
||||
2021-05-16 20:12:17,888:INFO:--> batch_size: 32
|
||||
2021-05-16 20:12:17,888:INFO:--> train_data: ../dataset/waterloo5050step40colorimage/
|
||||
2021-05-16 20:12:17,889:INFO:--> sigma: 15
|
||||
2021-05-16 20:12:17,889:INFO:--> channel: 3
|
||||
2021-05-16 20:12:17,889:INFO:--> epoch: 50
|
||||
2021-05-16 20:12:17,889:INFO:--> lr: 0.001
|
||||
2021-05-16 20:12:17,889:INFO:--> save_every: 1
|
||||
2021-05-16 20:12:17,889:INFO:--> pretrain: None
|
||||
2021-05-16 20:12:17,889:INFO:--> use_modelarts: False
|
||||
2021-05-16 20:12:17,889:INFO:--> train_url: train_url/
|
||||
2021-05-16 20:12:17,889:INFO:--> data_url: data_url/
|
||||
2021-05-16 20:12:17,889:INFO:--> output_path: ./output/
|
||||
2021-05-16 20:12:17,889:INFO:--> outer_path: s3://output/
|
||||
2021-05-16 20:12:17,889:INFO:--> device_target: Ascend
|
||||
2021-05-16 20:12:17,890:INFO:--> is_distributed: 0
|
||||
2021-05-16 20:12:17,890:INFO:--> rank: 0
|
||||
2021-05-16 20:12:17,890:INFO:--> group_size: 1
|
||||
2021-05-16 20:12:17,890:INFO:--> is_save_on_master: 1
|
||||
2021-05-16 20:12:17,890:INFO:--> ckpt_save_max: 5
|
||||
2021-05-16 20:12:17,890:INFO:--> rank_save_ckpt_flag: 1
|
||||
2021-05-16 20:12:17,890:INFO:--> logger: <LOGGER BRDNet (NOTSET)>
|
||||
2021-05-16 20:12:17,890:INFO:
|
||||
```

After training you will find the saved checkpoint files in the directory given by --output_path. Part of the loss curve during training looks as follows:

```text
|
||||
# grep "epoch time:" log.txt
|
||||
epoch time: 1197471.061 ms, per step time: 32.853 ms
|
||||
epoch time: 1136826.065 ms, per step time: 31.189 ms
|
||||
epoch time: 1136840.334 ms, per step time: 31.190 ms
|
||||
epoch time: 1136837.709 ms, per step time: 31.190 ms
|
||||
epoch time: 1137081.757 ms, per step time: 31.197 ms
|
||||
epoch time: 1136830.581 ms, per step time: 31.190 ms
|
||||
epoch time: 1136845.253 ms, per step time: 31.190 ms
|
||||
epoch time: 1136881.960 ms, per step time: 31.191 ms
|
||||
epoch time: 1136850.673 ms, per step time: 31.190 ms
|
||||
epoch: 10 step: 36449, loss is 103.104095
|
||||
epoch time: 1137098.407 ms, per step time: 31.197 ms
|
||||
epoch time: 1136794.613 ms, per step time: 31.189 ms
|
||||
epoch time: 1136742.922 ms, per step time: 31.187 ms
|
||||
epoch time: 1136842.009 ms, per step time: 31.190 ms
|
||||
epoch time: 1136792.705 ms, per step time: 31.189 ms
|
||||
epoch time: 1137056.362 ms, per step time: 31.196 ms
|
||||
epoch time: 1136863.373 ms, per step time: 31.191 ms
|
||||
epoch time: 1136842.938 ms, per step time: 31.190 ms
|
||||
epoch time: 1136839.011 ms, per step time: 31.190 ms
|
||||
epoch time: 1136879.794 ms, per step time: 31.191 ms
|
||||
epoch: 20 step: 36449, loss is 61.104546
|
||||
epoch time: 1137035.395 ms, per step time: 31.195 ms
|
||||
epoch time: 1136830.626 ms, per step time: 31.190 ms
|
||||
epoch time: 1136862.117 ms, per step time: 31.190 ms
|
||||
epoch time: 1136812.265 ms, per step time: 31.189 ms
|
||||
epoch time: 1136821.096 ms, per step time: 31.189 ms
|
||||
epoch time: 1137050.310 ms, per step time: 31.196 ms
|
||||
epoch time: 1136815.292 ms, per step time: 31.189 ms
|
||||
epoch time: 1136817.757 ms, per step time: 31.189 ms
|
||||
epoch time: 1136876.477 ms, per step time: 31.191 ms
|
||||
epoch time: 1136798.538 ms, per step time: 31.189 ms
|
||||
epoch: 30 step: 36449, loss is 116.179596
|
||||
epoch time: 1136972.930 ms, per step time: 31.194 ms
|
||||
epoch time: 1136825.174 ms, per step time: 31.189 ms
|
||||
epoch time: 1136798.900 ms, per step time: 31.189 ms
|
||||
epoch time: 1136828.101 ms, per step time: 31.190 ms
|
||||
epoch time: 1136862.983 ms, per step time: 31.191 ms
|
||||
epoch time: 1136989.445 ms, per step time: 31.194 ms
|
||||
epoch time: 1136688.820 ms, per step time: 31.186 ms
|
||||
epoch time: 1136858.111 ms, per step time: 31.190 ms
|
||||
epoch time: 1136822.853 ms, per step time: 31.189 ms
|
||||
epoch time: 1136782.455 ms, per step time: 31.188 ms
|
||||
epoch: 40 step: 36449, loss is 70.95368
|
||||
epoch time: 1137042.689 ms, per step time: 31.195 ms
|
||||
epoch time: 1136797.706 ms, per step time: 31.189 ms
|
||||
epoch time: 1136817.007 ms, per step time: 31.189 ms
|
||||
epoch time: 1136861.577 ms, per step time: 31.190 ms
|
||||
epoch time: 1136698.149 ms, per step time: 31.186 ms
|
||||
epoch time: 1137052.034 ms, per step time: 31.196 ms
|
||||
epoch time: 1136809.339 ms, per step time: 31.189 ms
|
||||
epoch time: 1136851.343 ms, per step time: 31.190 ms
|
||||
epoch time: 1136761.354 ms, per step time: 31.188 ms
|
||||
epoch time: 1136837.762 ms, per step time: 31.190 ms
|
||||
epoch: 50 step: 36449, loss is 87.13184
|
||||
epoch time: 1137022.554 ms, per step time: 31.195 ms
|
||||
2021-05-19 14:24:52,695:INFO:training finished....
|
||||
...
|
||||
```

The loss convergence for 8-device parallel training:

```text
|
||||
epoch time: 217708.130 ms, per step time: 47.785 ms
|
||||
epoch time: 144899.598 ms, per step time: 31.804 ms
|
||||
epoch time: 144736.054 ms, per step time: 31.768 ms
|
||||
epoch time: 144737.085 ms, per step time: 31.768 ms
|
||||
epoch time: 144738.102 ms, per step time: 31.769 ms
|
||||
epoch: 5 step: 4556, loss is 106.67432
|
||||
epoch time: 144905.830 ms, per step time: 31.805 ms
|
||||
epoch time: 144736.539 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.210 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.415 ms, per step time: 31.768 ms
|
||||
epoch time: 144736.405 ms, per step time: 31.768 ms
|
||||
epoch: 10 step: 4556, loss is 94.092865
|
||||
epoch time: 144921.081 ms, per step time: 31.809 ms
|
||||
epoch time: 144735.718 ms, per step time: 31.768 ms
|
||||
epoch time: 144737.036 ms, per step time: 31.768 ms
|
||||
epoch time: 144737.733 ms, per step time: 31.769 ms
|
||||
epoch time: 144738.251 ms, per step time: 31.769 ms
|
||||
epoch: 15 step: 4556, loss is 99.18075
|
||||
epoch time: 144921.945 ms, per step time: 31.809 ms
|
||||
epoch time: 144734.948 ms, per step time: 31.768 ms
|
||||
epoch time: 144735.662 ms, per step time: 31.768 ms
|
||||
epoch time: 144733.871 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.722 ms, per step time: 31.768 ms
|
||||
epoch: 20 step: 4556, loss is 92.54497
|
||||
epoch time: 144907.430 ms, per step time: 31.806 ms
|
||||
epoch time: 144735.713 ms, per step time: 31.768 ms
|
||||
epoch time: 144733.781 ms, per step time: 31.768 ms
|
||||
epoch time: 144736.005 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.331 ms, per step time: 31.768 ms
|
||||
epoch: 25 step: 4556, loss is 90.98991
|
||||
epoch time: 144911.420 ms, per step time: 31.807 ms
|
||||
epoch time: 144734.535 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.851 ms, per step time: 31.768 ms
|
||||
epoch time: 144736.346 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.939 ms, per step time: 31.768 ms
|
||||
epoch: 30 step: 4556, loss is 114.33954
|
||||
epoch time: 144915.434 ms, per step time: 31.808 ms
|
||||
epoch time: 144737.336 ms, per step time: 31.769 ms
|
||||
epoch time: 144733.943 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.587 ms, per step time: 31.768 ms
|
||||
epoch time: 144735.043 ms, per step time: 31.768 ms
|
||||
epoch: 35 step: 4556, loss is 97.21166
|
||||
epoch time: 144912.719 ms, per step time: 31.807 ms
|
||||
epoch time: 144734.795 ms, per step time: 31.768 ms
|
||||
epoch time: 144733.824 ms, per step time: 31.768 ms
|
||||
epoch time: 144735.946 ms, per step time: 31.768 ms
|
||||
epoch time: 144734.930 ms, per step time: 31.768 ms
|
||||
epoch: 40 step: 4556, loss is 82.41978
|
||||
epoch time: 144901.017 ms, per step time: 31.804 ms
|
||||
epoch time: 144735.060 ms, per step time: 31.768 ms
|
||||
epoch time: 144733.657 ms, per step time: 31.768 ms
|
||||
epoch time: 144732.592 ms, per step time: 31.767 ms
|
||||
epoch time: 144731.292 ms, per step time: 31.767 ms
|
||||
epoch: 45 step: 4556, loss is 77.92129
|
||||
epoch time: 144909.250 ms, per step time: 31.806 ms
|
||||
epoch time: 144732.944 ms, per step time: 31.768 ms
|
||||
epoch time: 144733.161 ms, per step time: 31.768 ms
|
||||
epoch time: 144732.912 ms, per step time: 31.768 ms
|
||||
epoch time: 144733.709 ms, per step time: 31.768 ms
|
||||
epoch: 50 step: 4556, loss is 85.499596
|
||||
2021-05-19 02:44:44,219:INFO:training finished....
|
||||
```
### Evaluation Process

#### Evaluation

Before running the commands below, please make sure the checkpoint path used for evaluation is correct.

- Running on Ascend

```shell
# Start evaluation from the python command line. Note that test_dir must end with "/".
python eval.py \
    --test_dir=./Test/CBSD68/ \
    --sigma=15 \
    --channel=3 \
    --pretrain_path=./ \
    --ckpt_name=channel_3_sigma_15_rank_0-50_227800.ckpt \
    --use_modelarts=0 \
    --output_path=./output/ \
    --is_distributed=0 > log.txt 2>&1 &

# Start evaluation through sh. (The path format of test_dir and similar arguments does not matter;
# they are converted internally to absolute paths ending with "/".)
sh run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name]
```

```text
|
||||
2021-05-17 13:40:45,909:INFO:Start to test on ./Test/CBSD68/
|
||||
2021-05-17 13:40:46,447:INFO:load test weights from channel_3_sigma_15_rank_0-50_227800.ckpt
|
||||
2021-05-17 13:41:52,164:INFO:Before denoise: Average PSNR_b = 24.62, SSIM_b = 0.56;After denoise: Average PSNR = 34.05, SSIM = 0.94
|
||||
2021-05-17 13:41:52,207:INFO:testing finished....
|
||||
```
After evaluation finishes, you can find both the noisy images and the denoised images in the directory given by --output_path; the file names encode the result. For example, 00001_sigma15_psnr24.62.bmp is the noisy image (PSNR = 24.62 after adding noise) and 00001_psnr31.18.bmp is the denoised image (PSNR = 31.18 after denoising).

In addition, metrics.csv in that directory records the result for every test image, as shown below: psnr_b is the PSNR before denoising and psnr the PSNR after denoising; the ssim columns are analogous.
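One detail worth noting in the table below: psnr_b is the same 24.62 dB for every image. eval.py reseeds NumPy with the same seed before adding noise to each image, so images of the same size receive the same noise realization, and the noisy-input PSNR depends only on that noise. The expected value for sigma = 15 on images normalized to [0, 1] can be checked directly:

```python
# PSNR of the noisy input: the RMSE of the added Gaussian noise is about sigma/255.
import math

sigma = 15
print(20 * math.log10(1.0 / (sigma / 255.0)))  # ~24.6 dB, matching psnr_b below
```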
|
||||
|
||||
| | name | psnr_b | psnr | ssim_b | ssim |
|
||||
| ---- | ------- | ----------- | ----------- | ----------- | ----------- |
|
||||
| 0 | 1 | 24.61875916 | 31.17827606 | 0.716650724 | 0.910416007 |
|
||||
| 1 | 2 | 24.61875916 | 35.12858963 | 0.457143694 | 0.995960176 |
|
||||
| 2 | 3 | 24.61875916 | 34.90437698 | 0.465185702 | 0.935821533 |
|
||||
| 3 | 4 | 24.61875916 | 35.59785461 | 0.49323535 | 0.941600204 |
|
||||
| 4 | 5 | 24.61875916 | 32.9185257 | 0.605194688 | 0.958840668 |
|
||||
| 5 | 6 | 24.61875916 | 37.29947662 | 0.368243992 | 0.962466478 |
|
||||
| 6 | 7 | 24.61875916 | 33.59238052 | 0.622622728 | 0.930195987 |
|
||||
| 7 | 8 | 24.61875916 | 31.76290894 | 0.680295587 | 0.918859363 |
|
||||
| 8 | 9 | 24.61875916 | 34.13358688 | 0.55876708 | 0.939204693 |
|
||||
| 9 | 10 | 24.61875916 | 34.49848557 | 0.503289104 | 0.928179622 |
|
||||
| 10 | 11 | 24.61875916 | 34.38597107 | 0.656857133 | 0.961226702 |
|
||||
| 11 | 12 | 24.61875916 | 32.75747299 | 0.627940595 | 0.910765707 |
|
||||
| 12 | 13 | 24.61875916 | 34.52487564 | 0.54259634 | 0.936489582 |
|
||||
| 13 | 14 | 24.61875916 | 35.40441132 | 0.44824928 | 0.93462956 |
|
||||
| 14 | 15 | 24.61875916 | 32.72385788 | 0.61768961 | 0.91652298 |
|
||||
| 15 | 16 | 24.61875916 | 33.59120178 | 0.703662276 | 0.948698342 |
|
||||
| 16 | 17 | 24.61875916 | 36.85597229 | 0.365240872 | 0.940135658 |
|
||||
| 17 | 18 | 24.61875916 | 37.23021317 | 0.366332233 | 0.953653395 |
|
||||
| 18 | 19 | 24.61875916 | 33.49061584 | 0.546713233 | 0.928890586 |
|
||||
| 19 | 20 | 24.61875916 | 33.98015213 | 0.463814735 | 0.938398063 |
|
||||
| 20 | 21 | 24.61875916 | 32.15977859 | 0.714740098 | 0.945747674 |
|
||||
| 21 | 22 | 24.61875916 | 32.39984512 | 0.716880679 | 0.930429876 |
|
||||
| 22 | 23 | 24.61875916 | 34.22258759 | 0.569748521 | 0.945626318 |
|
||||
| 23 | 24 | 24.61875916 | 33.974823 | 0.603115499 | 0.941333234 |
|
||||
| 24 | 25 | 24.61875916 | 34.87198639 | 0.486003697 | 0.966141582 |
|
||||
| 25 | 26 | 24.61875916 | 33.2747879 | 0.593207896 | 0.917522907 |
|
||||
| 26 | 27 | 24.61875916 | 34.67901611 | 0.504613101 | 0.921615481 |
|
||||
| 27 | 28 | 24.61875916 | 37.70562363 | 0.331322074 | 0.977765024 |
|
||||
| 28 | 29 | 24.61875916 | 31.08887672 | 0.759773433 | 0.958483219 |
|
||||
| 29 | 30 | 24.61875916 | 34.48878479 | 0.502000451 | 0.915705442 |
|
||||
| 30 | 31 | 24.61875916 | 30.5480938 | 0.836367846 | 0.949165165 |
|
||||
| 31 | 32 | 24.61875916 | 32.08041382 | 0.745214283 | 0.941719413 |
|
||||
| 32 | 33 | 24.61875916 | 33.65553284 | 0.556162357 | 0.963605523 |
|
||||
| 33 | 34 | 24.61875916 | 36.87154388 | 0.384932011 | 0.93150568 |
|
||||
| 34 | 35 | 24.61875916 | 33.03263474 | 0.586027861 | 0.924151421 |
|
||||
| 35 | 36 | 24.61875916 | 31.80633736 | 0.572878599 | 0.84426564 |
|
||||
| 36 | 37 | 24.61875916 | 33.26797485 | 0.526310742 | 0.938789487 |
|
||||
| 37 | 38 | 24.61875916 | 33.71062469 | 0.554955184 | 0.914420724 |
|
||||
| 38 | 39 | 24.61875916 | 37.3455925 | 0.461908668 | 0.956513464 |
|
||||
| 39 | 40 | 24.61875916 | 33.92232895 | 0.554454744 | 0.89727515 |
|
||||
| 40 | 41 | 24.61875916 | 33.05244827 | 0.590977669 | 0.931611121 |
|
||||
| 41 | 42 | 24.61875916 | 34.60203552 | 0.492371827 | 0.927084684 |
|
||||
| 42 | 43 | 24.61875916 | 35.20042419 | 0.535991669 | 0.949365258 |
|
||||
| 43 | 44 | 24.61875916 | 33.47367096 | 0.614959836 | 0.954348624 |
|
||||
| 44 | 45 | 24.61875916 | 37.65309143 | 0.363631308 | 0.944297135 |
|
||||
| 45 | 46 | 24.61875916 | 31.95152092 | 0.709732175 | 0.924522877 |
|
||||
| 46 | 47 | 24.61875916 | 31.9910202 | 0.70427531 | 0.932488263 |
|
||||
| 47 | 48 | 24.61875916 | 34.96608353 | 0.585813344 | 0.969006479 |
|
||||
| 48 | 49 | 24.61875916 | 35.39241409 | 0.388898522 | 0.923762918 |
|
||||
| 49 | 50 | 24.61875916 | 32.11050415 | 0.653521299 | 0.938310325 |
|
||||
| 50 | 51 | 24.61875916 | 33.54981995 | 0.594990134 | 0.927819192 |
|
||||
| 51 | 52 | 24.61875916 | 35.79096603 | 0.371685684 | 0.922166049 |
|
||||
| 52 | 53 | 24.61875916 | 35.10015869 | 0.410564244 | 0.895557165 |
|
||||
| 53 | 54 | 24.61875916 | 34.12319565 | 0.591762364 | 0.925524533 |
|
||||
| 54 | 55 | 24.61875916 | 32.79537964 | 0.653338313 | 0.92444253 |
|
||||
| 55 | 56 | 24.61875916 | 29.90909004 | 0.826190114 | 0.943322361 |
|
||||
| 56 | 57 | 24.61875916 | 33.23035812 | 0.527200282 | 0.943572938 |
|
||||
| 57 | 58 | 24.61875916 | 34.56663132 | 0.409658968 | 0.898451686 |
|
||||
| 58 | 59 | 24.61875916 | 34.43690109 | 0.454208463 | 0.904649734 |
|
||||
| 59 | 60 | 24.61875916 | 35.0402565 | 0.409306735 | 0.902388573 |
|
||||
| 60 | 61 | 24.61875916 | 34.91940308 | 0.443635762 | 0.911728501 |
|
||||
| 61 | 62 | 24.61875916 | 38.25325394 | 0.42568323 | 0.965163887 |
|
||||
| 62 | 63 | 24.61875916 | 32.07671356 | 0.727443576 | 0.94306612 |
|
||||
| 63 | 64 | 24.61875916 | 31.72690964 | 0.671929657 | 0.902075231 |
|
||||
| 64 | 65 | 24.61875916 | 33.47768402 | 0.533677042 | 0.922906399 |
|
||||
| 65 | 66 | 24.61875916 | 42.14694977 | 0.266868591 | 0.991976082 |
|
||||
| 66 | 67 | 24.61875916 | 30.81770706 | 0.84768647 | 0.957114518 |
|
||||
| 67 | 68 | 24.61875916 | 32.24455261 | 0.623004258 | 0.97843051 |
|
||||
| 68 | Average | 24.61875916 | 34.05390495 | 0.555872787 | 0.935704286 |
### Ascend 310 Inference

- Running in an Ascend 310 environment

```shell
# Start inference through sh
sh run_infer_310.sh [model_path] [data_path] [noise_image_path] [sigma] [channel] [device_id]

# The command above performs every step required for inference. When it finishes, three log files are produced:
# preprocess.log, infer.log and psnr.log.
# If you need to run the stages separately, follow the flow inside run_infer_310.sh to compile the code,
# preprocess the images, run inference and compute PSNR; carefully check the parameters required by each stage.
```
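The inference stage writes its outputs as raw float32 .bin files under scripts/result_Files/. The helper below is a convenience sketch (not part of the repository) showing how such a file can be turned back into a viewable image; it mirrors what cal_psnr.py does when it reloads the results.

```python
import numpy as np
import PIL.Image as Image


def bin_to_image(bin_file, height=500, width=500, channel=3):
    """Convert one CHW float32 result file back into an 8-bit image."""
    y = np.fromfile(bin_file, dtype=np.float32).reshape(channel, height, width)
    y = np.clip(y.transpose((1, 2, 0)), 0, 1)  # CHW -> HWC, clamp to [0, 1]
    return Image.fromarray((y * 255).astype(np.uint8))
```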
## Model Description

### Performance

#### Evaluation Performance

BRDNet on "waterloo5050step40colorimage"

| Parameters                 | BRDNet                                                                            |
| -------------------------- | --------------------------------------------------------------------------------- |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G                                   |
| Uploaded Date              | 5/20/2021 (month/day/year)                                                        |
| MindSpore Version          | master                                                                            |
| Dataset                    | waterloo5050step40colorimage                                                      |
| Training Parameters        | epoch=50, batch_size=32, lr=0.001                                                 |
| Optimizer                  | Adam                                                                              |
| Loss Function              | MSELoss(reduction='sum')                                                          |
| Outputs                    | denoised images                                                                   |
| Loss                       | 80.839773                                                                         |
| Speed                      | 8p: about 7000 to 7400 FPS                                                        |
| Total time                 | 8p: about 2h 14min                                                                |
| Checkpoint for Fine tuning | 8p: 13.68 MB, 1p: 19.76 MB (.ckpt file)                                           |
| Scripts                    | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/brdnet>  |
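The loss in the table is the sum-reduced MSE over a whole batch, which is why it sits in the tens rather than near zero. A rough sketch of the objective (illustrative only; the actual training loop is in train.py, which is not part of this diff):

```python
# Sum-reduced MSE between the denoised batch and the clean batch. With
# batch_size=32, channel=3 and 50x50 patches the sum runs over
# 32 * 3 * 50 * 50 = 240,000 values, so a loss of ~80 corresponds to a
# per-pixel MSE of roughly 80 / 240000 ≈ 3.3e-4.
import mindspore.nn as nn
from src.models import BRDNet  # model definition shipped with this PR

net = BRDNet(3)                            # channel = 3 for color images
loss_fn = nn.MSELoss(reduction='sum')
net_with_loss = nn.WithLossCell(net, loss_fn)
# loss = net_with_loss(noisy_batch, clean_batch)  # both NCHW float32 Tensors
```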

## Description of Random Situation

A random seed is set in train.py.

## ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,14 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(Ascend310Infer)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
option(MINDSPORE_PATH "mindspore install path" "")
|
||||
include_directories(${MINDSPORE_PATH})
|
||||
include_directories(${MINDSPORE_PATH}/include)
|
||||
include_directories(${PROJECT_SRC_ROOT})
|
||||
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||
|
||||
add_executable(main src/main.cc src/utils.cc)
|
||||
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ -d out ]; then
|
||||
rm -rf out
|
||||
fi
|
||||
|
||||
mkdir out
|
||||
cd out || exit
|
||||
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
|
||||
cmake .. \
|
||||
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||
make
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <dirent.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||
DIR *OpenDir(std::string_view dirName);
|
||||
std::string RealPath(std::string_view path);
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||
#endif
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <sys/time.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <random>
|
||||
#include <ctime>
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/minddata/dataset/include/vision_ascend.h"
|
||||
#include "include/minddata/dataset/include/execute.h"
|
||||
#include "include/minddata/dataset/include/transforms.h"
|
||||
#include "include/minddata/dataset/include/constants.h"
|
||||
#include "include/minddata/dataset/include/vision.h"
|
||||
#include "inc/utils.h"
|
||||
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::Context;
|
||||
using mindspore::Status;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::Graph;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
using mindspore::dataset::Execute;
|
||||
using mindspore::dataset::TensorTransform;
|
||||
using mindspore::dataset::vision::Decode;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::transforms::TypeCast;
|
||||
|
||||
DEFINE_string(model_path, "", "model path");
|
||||
DEFINE_string(dataset_path, "", "dataset path");
|
||||
DEFINE_int32(input_width, 500, "input width");
|
||||
DEFINE_int32(input_height, 500, "input height");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
if (RealPath(FLAGS_model_path).empty()) {
|
||||
std::cout << "Invalid mindir" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310_info->SetDeviceID(FLAGS_device_id);
|
||||
context->MutableDeviceInfo().push_back(ascend310_info);
|
||||
|
||||
Graph graph;
|
||||
Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Load model failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
Model model;
|
||||
ret = model.Build(GraphCell(graph), context);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Build failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<MSTensor> modelInputs = model.GetInputs();
|
||||
|
||||
auto all_files = GetAllFiles(FLAGS_dataset_path);
|
||||
if (all_files.empty()) {
|
||||
std::cout << "ERROR: no input data." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::map<double, double> costTime_map;
|
||||
size_t size = all_files.size();
|
||||
// Define transform
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
struct timeval start;
|
||||
struct timeval end;
|
||||
double startTime_ms;
|
||||
double endTime_ms;
|
||||
|
||||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
|
||||
std::cout << "Start predict input files:" << all_files[i] << std::endl;
|
||||
|
||||
auto image = ReadFileToTensor(all_files[i]);
|
||||
|
||||
inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
|
||||
image.Data().get(), image.DataSize());
|
||||
|
||||
gettimeofday(&start, NULL);
|
||||
ret = model.Predict(inputs, &outputs);
|
||||
gettimeofday(&end, NULL);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||
endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||
costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
|
||||
WriteResult(all_files[i], outputs);
|
||||
}
|
||||
|
||||
double average = 0.0;
|
||||
int infer_cnt = 0;
|
||||
|
||||
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||
double diff = 0.0;
|
||||
diff = iter->second - iter->first;
|
||||
average += diff;
|
||||
infer_cnt++;
|
||||
}
|
||||
|
||||
average = average / infer_cnt;
|
||||
std::stringstream timeCost;
|
||||
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << infer_cnt << std::endl;
|
||||
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
|
||||
std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
|
||||
std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
|
||||
file_stream << timeCost.str();
|
||||
file_stream.close();
|
||||
costTime_map.clear();
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "inc/utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
using mindspore::MSTensor;
|
||||
using mindspore::DataType;
|
||||
|
||||
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||
struct dirent *filename;
|
||||
DIR *dir = OpenDir(dirName);
|
||||
if (dir == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> res;
|
||||
while ((filename = readdir(dir)) != nullptr) {
|
||||
std::string dName = std::string(filename->d_name);
|
||||
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||
continue;
|
||||
}
|
||||
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||
}
|
||||
std::sort(res.begin(), res.end());
|
||||
for (auto &f : res) {
|
||||
std::cout << "image file: " << f << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||
std::string homePath = "./result_Files";
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
size_t outputSize;
|
||||
std::shared_ptr<const void> netOutput = outputs[i].Data();
|
||||
outputSize = outputs[i].DataSize();
|
||||
int pos = imageFile.rfind('/');
|
||||
std::string fileName(imageFile, pos + 1);
|
||||
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||
std::string outFileName = homePath + "/" + fileName;
|
||||
FILE *outputFile = fopen(outFileName.c_str(), "wb");
|
||||
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
std::cout << "Pointer file is nullptr" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "File: " << file << "open failed" << std::endl;
|
||||
return mindspore::MSTensor();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
|
||||
DIR *OpenDir(std::string_view dirName) {
|
||||
if (dirName.empty()) {
|
||||
std::cout << " dirName is null ! " << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = RealPath(dirName);
|
||||
struct stat s;
|
||||
lstat(realPath.c_str(), &s);
|
||||
if (!S_ISDIR(s.st_mode)) {
|
||||
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
DIR *dir = opendir(realPath.c_str());
|
||||
if (dir == nullptr) {
|
||||
std::cout << "Can not open dir " << dirName << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||
return dir;
|
||||
}
|
||||
|
||||
std::string RealPath(std::string_view path) {
|
||||
char realPathMem[PATH_MAX] = {0};
|
||||
char *realPathRet = nullptr;
|
||||
realPathRet = realpath(path.data(), realPathMem);
|
||||
if (realPathRet == nullptr) {
|
||||
std::cout << "File: " << path << " is not exist.";
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string realPath(realPathMem);
|
||||
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||
return realPath;
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
import math
|
||||
import numpy as np
|
||||
import PIL.Image as Image
|
||||
|
||||
## Params
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--image_path', default='./Test/Kodak24/'
|
||||
, type=str, help='directory of test dataset')
|
||||
parser.add_argument('--output_path', default='', type=str,
|
||||
help='path of the predict files that generated by the model')
|
||||
parser.add_argument('--image_width', default=500, type=int, help='image_width')
|
||||
parser.add_argument('--image_height', default=500, type=int, help='image_height')
|
||||
parser.add_argument('--channel', default=3, type=int
|
||||
, help='image channel, 3 for color, 1 for gray')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def calculate_psnr(image1, image2):
|
||||
image1 = np.float64(image1)
|
||||
image2 = np.float64(image2)
|
||||
diff = image1 - image2
|
||||
diff = diff.flatten('C')
|
||||
rmse = math.sqrt(np.mean(diff**2.))
|
||||
return 20*math.log10(1.0/rmse)
|
||||
|
||||
def cal_psnr(output_path, image_path):
|
||||
file_list = glob.glob(image_path+'*') # image_path must end by '/'
|
||||
psnr = [] #after denoise
|
||||
|
||||
start_time = time.time()
|
||||
for file in file_list:
|
||||
filename = file.split('/')[-1].split('.')[0] # get the name of image file
|
||||
# read image
|
||||
if args.channel == 3:
|
||||
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
|
||||
else:
|
||||
img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)
|
||||
|
||||
result_file = os.path.join(output_path, filename+"_noise_0.bin")
|
||||
y_predict = np.fromfile(result_file, dtype=np.float32)
|
||||
y_predict = y_predict.reshape(args.channel, args.image_height, args.image_width)
|
||||
img_out = y_predict.transpose((1, 2, 0))#HWC
|
||||
img_out = np.clip(img_out, 0, 1)
|
||||
psnr_denoised = calculate_psnr(img_clean, img_out)
|
||||
psnr.append(psnr_denoised)
|
||||
print(filename, ": psnr_denoised: ", " ", psnr_denoised)
|
||||
|
||||
psnr_avg = sum(psnr)/len(psnr)
|
||||
print("Average PSNR:", psnr_avg)
|
||||
print("Testing finished....")
|
||||
time_used = time.time() - start_time
|
||||
print("Time cost:"+str(time_used)+" seconds!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
cal_psnr(args.output_path, args.image_path)
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import datetime
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import PIL.Image as Image
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.logger import get_logger
|
||||
from src.models import BRDNet
|
||||
|
||||
## Params
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--test_dir', default='./Test/Kodak24/'
|
||||
, type=str, help='directory of test dataset')
|
||||
parser.add_argument('--sigma', default=15, type=int, help='noise level')
|
||||
parser.add_argument('--channel', default=3, type=int
|
||||
, help='image channel, 3 for color, 1 for gray')
|
||||
parser.add_argument('--pretrain_path', default=None, type=str, help='path of pre-trained model')
|
||||
parser.add_argument('--ckpt_name', default=None, type=str, help='ckpt_name')
|
||||
parser.add_argument('--use_modelarts', type=int, default=0
|
||||
, help='1 for True, 0 for False;when set True, we should load dataset from obs with moxing')
|
||||
parser.add_argument('--train_url', type=str, default='train_url/'
|
||||
, help='needed by modelarts, but we do not use it because the name is ambiguous')
|
||||
parser.add_argument('--data_url', type=str, default='data_url/'
|
||||
, help='needed by modelarts, but we do not use it because the name is ambiguous')
|
||||
parser.add_argument('--output_path', type=str, default='./output/'
|
||||
, help='output_path,when use_modelarts is set True, it will be cache/output/')
|
||||
parser.add_argument('--outer_path', type=str, default='s3://output/'
|
||||
, help='obs path,to store e.g ckpt files ')
|
||||
|
||||
parser.add_argument('--device_target', type=str, default='Ascend',
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
|
||||
set_seed(1)
|
||||
|
||||
args = parser.parse_args()
|
||||
save_dir = os.path.join(args.output_path, 'sigma_' + str(args.sigma) + \
|
||||
'_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
|
||||
if not args.use_modelarts and not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
def test(model_path):
|
||||
args.logger.info('Start to test on {}'.format(args.test_dir))
|
||||
out_dir = os.path.join(save_dir, args.test_dir.split('/')[-2]) # args.test_dir must end by '/'
|
||||
if not args.use_modelarts and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
|
||||
model = BRDNet(args.channel)
|
||||
args.logger.info('load test weights from '+str(model_path))
|
||||
load_param_into_net(model, load_checkpoint(model_path))
|
||||
name = []
|
||||
psnr = [] #after denoise
|
||||
ssim = [] #after denoise
|
||||
psnr_b = [] #before denoise
|
||||
ssim_b = [] #before denoise
|
||||
|
||||
if args.use_modelarts:
|
||||
args.logger.info("copying test dataset from obs to cache....")
|
||||
mox.file.copy_parallel(args.test_dir, 'cache/test')
|
||||
args.logger.info("copying test dataset finished....")
|
||||
args.test_dir = 'cache/test/'
|
||||
|
||||
file_list = glob.glob(args.test_dir+'*') # args.test_dir must end by '/'
|
||||
model.set_train(False)
|
||||
|
||||
cast = P.Cast()
|
||||
transpose = P.Transpose()
|
||||
expand_dims = P.ExpandDims()
|
||||
compare_psnr = nn.PSNR()
|
||||
compare_ssim = nn.SSIM()
|
||||
|
||||
args.logger.info("start testing....")
|
||||
start_time = time.time()
|
||||
for file in file_list:
|
||||
suffix = file.split('.')[-1]
|
||||
# read image
|
||||
if args.channel == 3:
|
||||
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
|
||||
else:
|
||||
img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)
|
||||
|
||||
np.random.seed(0) #obtain the same random data when it is in the test phase
|
||||
img_test = img_clean + np.random.normal(0, args.sigma/255.0, img_clean.shape)
|
||||
|
||||
img_clean = Tensor(img_clean, mindspore.float32) #HWC
|
||||
img_test = Tensor(img_test, mindspore.float32) #HWC
|
||||
|
||||
# predict
|
||||
img_clean = expand_dims(transpose(img_clean, (2, 0, 1)), 0)#NCHW
|
||||
img_test = expand_dims(transpose(img_test, (2, 0, 1)), 0)#NCHW
|
||||
|
||||
y_predict = model(img_test) #NCHW
|
||||
|
||||
# calculate numeric metrics
|
||||
|
||||
img_out = C.clip_by_value(y_predict, 0, 1)
|
||||
|
||||
psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test), compare_psnr(img_clean, img_out)
|
||||
ssim_noise, ssim_denoised = compare_ssim(img_clean, img_test), compare_ssim(img_clean, img_out)
|
||||
|
||||
|
||||
psnr.append(psnr_denoised.asnumpy()[0])
|
||||
ssim.append(ssim_denoised.asnumpy()[0])
|
||||
psnr_b.append(psnr_noise.asnumpy()[0])
|
||||
ssim_b.append(ssim_noise.asnumpy()[0])
|
||||
|
||||
# save images
|
||||
filename = file.split('/')[-1].split('.')[0] # get the name of image file
|
||||
name.append(filename)
|
||||
|
||||
if not args.use_modelarts:
|
||||
# Image.save first checks whether a file with the same name already exists,
# which is not allowed on ModelArts, so images are only saved when running locally.
|
||||
img_test = cast(img_test*255, mindspore.uint8).asnumpy()
|
||||
img_test = img_test.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
|
||||
img_test = Image.fromarray(img_test)
|
||||
img_test.save(os.path.join(out_dir, filename+'_sigma'+'{}_psnr{:.2f}.'\
|
||||
.format(args.sigma, psnr_noise.asnumpy()[0])+str(suffix)))
|
||||
img_out = cast(img_out*255, mindspore.uint8).asnumpy()
|
||||
img_out = img_out.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
|
||||
img_out = Image.fromarray(img_out)
|
||||
img_out.save(os.path.join(out_dir, filename+'_psnr{:.2f}.'.format(psnr_denoised.asnumpy()[0])+str(suffix)))
|
||||
|
||||
|
||||
psnr_avg = sum(psnr)/len(psnr)
|
||||
ssim_avg = sum(ssim)/len(ssim)
|
||||
psnr_avg_b = sum(psnr_b)/len(psnr_b)
|
||||
ssim_avg_b = sum(ssim_b)/len(ssim_b)
|
||||
name.append('Average')
|
||||
psnr.append(psnr_avg)
|
||||
ssim.append(ssim_avg)
|
||||
psnr_b.append(psnr_avg_b)
|
||||
ssim_b.append(ssim_avg_b)
|
||||
args.logger.info('Before denoise: Average PSNR_b = {0:.2f}, \
|
||||
SSIM_b = {1:.2f};After denoise: Average PSNR = {2:.2f}, SSIM = {3:.2f}'\
|
||||
.format(psnr_avg_b, ssim_avg_b, psnr_avg, ssim_avg))
|
||||
args.logger.info("testing finished....")
|
||||
time_used = time.time() - start_time
|
||||
args.logger.info("time cost:"+str(time_used)+" seconds!")
|
||||
if not args.use_modelarts:
|
||||
pd.DataFrame({'name': np.array(name), 'psnr_b': np.array(psnr_b), \
|
||||
'psnr': np.array(psnr), 'ssim_b': np.array(ssim_b), \
|
||||
'ssim': np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, device_id=device_id, save_graphs=False)
|
||||
|
||||
args.logger = get_logger(save_dir, "BRDNet", 0)
|
||||
args.logger.save_args(args)
|
||||
|
||||
if args.use_modelarts:
|
||||
import moxing as mox
|
||||
args.logger.info("copying test weights from obs to cache....")
|
||||
mox.file.copy_parallel(args.pretrain_path, 'cache/weight')
|
||||
args.logger.info("copying test weights finished....")
|
||||
args.pretrain_path = 'cache/weight/'
|
||||
|
||||
test(os.path.join(args.pretrain_path, args.ckpt_name))
|
||||
|
||||
if args.use_modelarts:
|
||||
args.logger.info("copying files from cache to obs....")
|
||||
mox.file.copy_parallel(save_dir, args.outer_path)
|
||||
args.logger.info("copying finished....")
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.models import BRDNet
|
||||
|
||||
|
||||
## Params
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--batch_size', default=32, type=int, help='batch size')
|
||||
parser.add_argument('--channel', default=3, type=int
|
||||
, help='image channel, 3 for color, 1 for gray')
|
||||
parser.add_argument("--image_height", type=int, default=50, help="Image height.")
|
||||
parser.add_argument("--image_width", type=int, default=50, help="Image width.")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="brdnet", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument('--device_target', type=str, default='Ascend'
|
||||
, help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
net = BRDNet(args_opt.channel)
|
||||
|
||||
param_dict = load_checkpoint(args_opt.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([args_opt.batch_size, args_opt.channel, \
|
||||
args_opt.image_height, args_opt.image_width]), ms.float32)
|
||||
export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import PIL.Image as Image
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--out_dir', type=str, required=True,
|
||||
help='directory to store the image with noise')
|
||||
parser.add_argument('--image_path', type=str, required=True,
|
||||
help='directory of image to add noise')
|
||||
parser.add_argument('--channel', type=int, default=3
|
||||
, help='image channel, 3 for color, 1 for gray')
|
||||
parser.add_argument('--sigma', type=int, default=15, help='level of noise')
|
||||
args = parser.parse_args()
|
||||
|
||||
def add_noise(out_dir, image_path, channel, sigma):
|
||||
file_list = glob.glob(image_path+'*') # image_path must end by '/'
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
|
||||
for file in file_list:
|
||||
print("Adding noise to: ", file)
|
||||
# read image
|
||||
if channel == 3:
|
||||
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
|
||||
else:
|
||||
img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)
|
||||
|
||||
np.random.seed(0) #obtain the same random data when it is in the test phase
|
||||
img_test = img_clean + np.random.normal(0, sigma/255.0, img_clean.shape).astype(np.float32)#HWC
|
||||
img_test = np.expand_dims(img_test.transpose((2, 0, 1)), 0)#NCHW
|
||||
#img_test = np.clip(img_test, 0, 1)
|
||||
|
||||
filename = file.split('/')[-1].split('.')[0] # get the name of image file
|
||||
img_test.tofile(os.path.join(out_dir, filename+'_noise.bin'))
|
||||
|
||||
if __name__ == "__main__":
|
||||
add_noise(args.out_dir, args.image_path, args.channel, args.sigma)
|
|
@ -0,0 +1,83 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 8 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [train_code_path] [train_data]" \
|
||||
"[batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)/"
|
||||
fi
|
||||
}
|
||||
|
||||
get_real_path_name() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
train_code_path=$(get_real_path $1)
|
||||
echo $train_code_path
|
||||
|
||||
if [ ! -d $train_code_path ]
|
||||
then
|
||||
echo "error: train_code_path=$train_code_path is not a dictionary."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
train_data=$(get_real_path $2)
|
||||
echo $train_data
|
||||
|
||||
if [ ! -d $train_data ]
|
||||
then
|
||||
echo "error: train_data=$train_data is not a dictionary."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
rank_table_file_path=$(get_real_path_name $8)
|
||||
echo "rank_table_file_path: "$rank_table_file_path
|
||||
|
||||
ulimit -c unlimited
|
||||
export SLOG_PRINT_TO_STDOUT=0
|
||||
export RANK_TABLE_FILE=$rank_table_file_path
|
||||
export RANK_SIZE=8
|
||||
export RANK_START_ID=0
|
||||
|
||||
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||
do
|
||||
export RANK_ID=${i}
|
||||
export DEVICE_ID=$((i + RANK_START_ID))
|
||||
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
||||
if [ -d ${train_code_path}device${DEVICE_ID} ]; then
|
||||
rm -rf ${train_code_path}device${DEVICE_ID}
|
||||
fi
|
||||
mkdir ${train_code_path}device${DEVICE_ID}
|
||||
cd ${train_code_path}device${DEVICE_ID} || exit
|
||||
nohup python ${train_code_path}train.py --is_distributed=1 \
|
||||
--train_data=${train_data} \
|
||||
--batch_size=$3 \
|
||||
--sigma=$4 \
|
||||
--channel=$5 \
|
||||
--epoch=$6 \
|
||||
--lr=$7 > ${train_code_path}device${DEVICE_ID}/log.txt 2>&1 &
|
||||
done
|
|
@ -0,0 +1,63 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 6 ]; then
    echo "Usage: sh run_eval.sh [train_code_path] [test_dir] [sigma]" \
         "[channel] [pretrain_path] [ckpt_name]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)/"
    fi
}

train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path

test_dir=$(get_real_path $2)
echo "test_dir: "$test_dir

pretrain_path=$(get_real_path $5)
echo "pretrain_path: "$pretrain_path

if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
    exit 1
fi

if [ ! -d $test_dir ]
then
    echo "error: test_dir=$test_dir is not a directory."
    exit 1
fi

if [ ! -d $pretrain_path ]
then
    echo "error: pretrain_path=$pretrain_path is not a directory."
    exit 1
fi

python ${train_code_path}eval.py \
    --test_dir=$test_dir \
    --sigma=$3 \
    --channel=$4 \
    --pretrain_path=$pretrain_path \
    --ckpt_name=$6
@ -0,0 +1,131 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 6 ]; then
    echo "Usage: sh run_infer_310.sh [model_path] [data_path]" \
         "[noise_image_path] [sigma] [channel] [device_id]"
    exit 1
fi

get_real_path_name() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)/"
    fi
}

model=$(get_real_path_name $1)
data_path=$(get_real_path $2)
noise_image_path=$(get_real_path $3)
sigma=$4
channel=$5
device_id=$6

echo "model path: "$model
echo "dataset path: "$data_path
echo "noise image path: "$noise_image_path
echo "sigma: "$sigma
echo "channel: "$channel
echo "device id: "$device_id

export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

function compile_app()
{
    cd ../ascend310_infer/ || exit
    if [ -f "Makefile" ]; then
        make clean
    fi
    sh build.sh &> build.log
}

function preprocess()
{
    cd ../ || exit
    echo -e "\nstart preprocess"
    echo "waiting for preprocess to finish..."
    python3.7 preprocess.py --out_dir=$noise_image_path --image_path=$data_path --channel=$channel --sigma=$sigma > preprocess.log 2>&1
    echo "preprocess finished! you can see the log in preprocess.log"
}

function infer()
{
    cd ./scripts || exit
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    if [ -d time_Result ]; then
        rm -rf ./time_Result
    fi
    mkdir result_Files
    mkdir time_Result
    echo -e "\nstart infer..."
    echo "waiting for infer to finish..."
    ../ascend310_infer/out/main --model_path=$model --dataset_path=$noise_image_path --device_id=$device_id > infer.log 2>&1
    echo "infer finished! you can see the log in infer.log"
}

function cal_psnr()
{
    echo -e "\nstart calculating PSNR..."
    echo "waiting for the calculation to finish..."
    python3.7 ../cal_psnr.py --image_path=$data_path --output_path=./result_Files --channel=$channel > psnr.log 2>&1
    echo -e "calculation finished! you can see the log in psnr.log\n"
}

compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi

preprocess
if [ $? -ne 0 ]; then
    echo "execute preprocess failed"
    exit 1
fi

infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi

cal_psnr
if [ $? -ne 0 ]; then
    echo "calculate psnr failed"
    exit 1
fi
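`cal_psnr.py`, invoked in the last step above, is not part of this excerpt. As a hedged reference, the sketch below shows the standard PSNR formula such a script typically computes for images scaled to [0, 1]; the arrays are placeholder data, not the actual 310 inference outputs:

```python
# Minimal sketch of the PSNR metric used to score denoising results.
# This is not cal_psnr.py itself; it only illustrates the standard formula
# for float images scaled to [0, 1].
import numpy as np

def psnr(clean, denoised, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two float images."""
    mse = np.mean((clean.astype(np.float64) - denoised.astype(np.float64)) ** 2)
    return 10.0 * np.log10((data_range ** 2) / mse)

clean = np.random.rand(3, 256, 256).astype(np.float32)          # placeholder data
denoised = clean + np.random.normal(0, 0.01, clean.shape)        # pretend reconstruction
print('PSNR: %.2f dB' % psnr(clean, denoised))
```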
@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 7 ]; then
    echo "Usage: sh run_train.sh [train_code_path] [train_data]" \
         "[batch_size] [sigma] [channel] [epoch] [lr]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)/"
    fi
}

train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path

train_data=$(get_real_path $2)
echo "train_data: "$train_data

if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
    exit 1
fi

if [ ! -d $train_data ]
then
    echo "error: train_data=$train_data is not a directory."
    exit 1
fi

nohup python ${train_code_path}train.py --is_distributed=0 \
                                        --train_data=${train_data} \
                                        --batch_size=$3 \
                                        --sigma=$4 \
                                        --channel=$5 \
                                        --epoch=$6 \
                                        --lr=$7 > log.txt 2>&1 &
echo 'Train task has been started successfully!'
echo 'Please check the log at log.txt'
@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import glob
import numpy as np
import PIL.Image as Image
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler


class BRDNetDataset:
    """ BRDNetDataset.

    Args:
        data_path: path of images, must end with '/'
        sigma: noise level
        channel: 3 for color, 1 for gray
    """
    def __init__(self, data_path, sigma, channel):
        images = []
        file_list = glob.glob(data_path+'*.bmp')  # note the data format: BMP patches
        for file in file_list:
            images.append(file)
        self.images = images
        self.sigma = sigma
        self.channel = channel

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            get_batch_x: image with noise
            get_batch_y: image without noise
        """
        if self.channel == 3:
            get_batch_y = np.array(Image.open(self.images[index]), dtype='uint8')
        else:
            get_batch_y = np.expand_dims(np.array(Image.open(self.images[index]).convert('L'), dtype='uint8'), axis=2)

        get_batch_y = get_batch_y.astype('float32')/255.0
        noise = np.random.normal(0, self.sigma/255.0, get_batch_y.shape).astype('float32')  # noise
        get_batch_x = get_batch_y + noise  # input image = clean image + noise
        return get_batch_x, get_batch_y

    def __len__(self):
        return len(self.images)


def create_BRDNetDataset(data_path, sigma, channel, batch_size, device_num, rank, shuffle):

    dataset = BRDNetDataset(data_path, sigma, channel)
    dataset_len = len(dataset)
    distributed_sampler = DistributedSampler(dataset_len, device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()
    data_set = ds.GeneratorDataset(dataset, column_names=["image", "label"], \
                                   shuffle=shuffle, sampler=distributed_sampler)
    data_set = data_set.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
    data_set = data_set.map(input_columns=["label"], operations=hwc_to_chw, num_parallel_workers=8)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set, dataset_len
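A hypothetical single-device usage sketch of `create_BRDNetDataset`; the data path, noise level and batch size are placeholders, and a working MindSpore install plus the 50x50 BMP training patches are assumed:

```python
# Illustrative only: build and iterate the training dataset on one device.
from src.dataset import create_BRDNetDataset

ds_train, num_images = create_BRDNetDataset(
    data_path='xxx/dataset/waterloo5050step40colorimage/',  # must end with '/'
    sigma=15, channel=3, batch_size=32,
    device_num=1, rank=0, shuffle=True)

for noisy, clean in ds_train.create_tuple_iterator():
    # clean is a float32 NCHW batch in [0, 1]; noisy = clean + Gaussian noise (sigma/255)
    print(noisy.shape, clean.shape)   # expected: (32, 3, 50, 50) (32, 3, 50, 50)
    break
```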
@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BRDNetDataset distributed sampler."""
from __future__ import division

import math
import numpy as np


class DistributedSampler:
    """Distributed sampler."""
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array type; numbers from 0 to dataset_size-1, used as indices of the dataset
            indices = indices.tolist()
            self.epoch += 1
            # change to list type
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
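To illustrate how `DistributedSampler` pads and splits indices, here is a small self-contained example with 10 samples and 4 replicas; the printed values follow directly from the logic above:

```python
# Illustrative only: with 10 samples and 4 replicas, num_samples = ceil(10/4) = 3,
# so two indices are repeated to reach total_size = 12 before per-rank subsampling.
from src.distributed_sampler import DistributedSampler

for rank in range(4):
    sampler = DistributedSampler(dataset_size=10, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(iter(sampler)))
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6, 0], rank 3 -> [3, 7, 1]
```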
@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime


class LOGGER(logging.Logger):
    """
    Logger.

    Args:
        logger_name: String. Logger name.
        rank: Integer. Rank id.
    """
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.rank = rank
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """Setup logging file."""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, logger_name, rank):
    """Get Logger."""
    logger = LOGGER(logger_name, rank)
    logger.setup_logging_file(path, rank)
    return logger
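A hypothetical usage sketch of the custom logger; the output directory is a placeholder, and `save_args` accepts any argparse `Namespace`:

```python
# Illustrative only: console output plus a timestamped per-rank log file.
import argparse
from src.logger import get_logger

logger = get_logger('./output/', 'BRDNet', rank=0)          # also writes ./output/<timestamp>_rank_0.log
logger.save_args(argparse.Namespace(sigma=15, channel=3))    # prints each argument with a '-->' prefix
logger.info('PSNR: %.3f dB', 34.57)
logger.important_info('training finished')                   # banner message, rank 0 only
```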
@ -0,0 +1,150 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops as ops
#from batch_renorm import BatchRenormalization

from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import functional as F


class BRDNet(nn.Cell):
    """
    args:
        channel: 3 for color, 1 for gray
    """
    def __init__(self, channel):
        super(BRDNet, self).__init__()

        self.Conv2d_1 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_1 = nn.BatchNorm2d(64, eps=1e-3)
        self.layer1 = self.make_layer1(15)
        self.Conv2d_2 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.Conv2d_3 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_2 = nn.BatchNorm2d(64, eps=1e-3)
        self.layer2 = self.make_layer2(7)
        self.Conv2d_4 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_3 = nn.BatchNorm2d(64, eps=1e-3)
        self.layer3 = self.make_layer2(6)
        self.Conv2d_5 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_4 = nn.BatchNorm2d(64, eps=1e-3)
        self.Conv2d_6 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.Conv2d_7 = nn.Conv2d(channel*2, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)

        self.relu = nn.ReLU()
        self.sub = ops.Sub()
        self.concat = ops.Concat(axis=1)  # NCHW

    def make_layer1(self, nums):
        layers = []
        assert nums > 0
        for _ in range(nums):
            layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True))
            layers.append(nn.BatchNorm2d(64, eps=1e-3))
            layers.append(nn.ReLU())
        return nn.SequentialCell(layers)

    def make_layer2(self, nums):
        layers = []
        assert nums > 0
        for _ in range(nums):
            layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), \
                                    stride=(1, 1), dilation=(2, 2), pad_mode='same', has_bias=True))
            layers.append(nn.ReLU())
        return nn.SequentialCell(layers)

    def construct(self, inpt):
        # inpt -----> 'NCHW'
        x = self.Conv2d_1(inpt)
        x = self.BRN_1(x)
        x = self.relu(x)
        # 15 layers of Conv+BN+ReLU
        x = self.layer1(x)

        # last layer: Conv
        x = self.Conv2d_2(x)  # output channel: 1 for gray, 3 for color
        x = self.sub(inpt, x)  # input - noise

        y = self.Conv2d_3(inpt)
        y = self.BRN_2(y)
        y = self.relu(y)

        # first block of dilated Conv+ReLU layers
        y = self.layer2(y)

        y = self.Conv2d_4(y)
        y = self.BRN_3(y)
        y = self.relu(y)

        # second block of dilated Conv+ReLU layers
        y = self.layer3(y)

        y = self.Conv2d_5(y)
        y = self.BRN_4(y)
        y = self.relu(y)

        y = self.Conv2d_6(y)  # output channel: 1 for gray, 3 for color
        y = self.sub(inpt, y)  # input - noise

        o = self.concat((x, y))
        z = self.Conv2d_7(o)  # output channel: 1 for gray, 3 for color
        z = self.sub(inpt, z)

        return z


class BRDWithLossCell(nn.Cell):
    def __init__(self, network):
        super(BRDWithLossCell, self).__init__()
        self.network = network
        self.loss = nn.MSELoss(reduction='sum')  # we use 'sum' instead of 'mean' to keep the loss from becoming too small

    def construct(self, images, targets):
        output = self.network(images)
        return self.loss(output, targets)


class TrainingWrapper(nn.Cell):
    """Training wrapper."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
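A minimal smoke-test sketch for `BRDNet`, assuming a MindSpore build with a usable backend; it only checks that the residual "input - noise" output keeps the NCHW input shape:

```python
# Illustrative only: run one 50x50 color patch through the two-branch network.
import numpy as np
from mindspore import Tensor, context
from src.models import BRDNet

context.set_context(mode=context.PYNATIVE_MODE)                    # backend defaults to the local device
net = BRDNet(channel=3)
patch = Tensor(np.random.rand(1, 3, 50, 50).astype(np.float32))     # one NCHW patch, as used in training
restored = net(patch)
print(restored.shape)                                               # expected: (1, 3, 50, 50)
```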
@ -0,0 +1,182 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import datetime
import argparse

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore.train.callback import TimeMonitor, LossMonitor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_rank, get_group_size

from src.logger import get_logger
from src.dataset import create_BRDNetDataset
from src.models import BRDNet, BRDWithLossCell, TrainingWrapper


## Params
parser = argparse.ArgumentParser()

parser.add_argument('--batch_size', default=32, type=int, help='batch size')
parser.add_argument('--train_data', default='../dataset/waterloo5050step40colorimage/',
                    type=str, help='path of train data')
parser.add_argument('--sigma', default=75, type=int, help='noise level')
parser.add_argument('--channel', default=3, type=int,
                    help='image channel, 3 for color, 1 for gray')
parser.add_argument('--epoch', default=50, type=int, help='number of training epochs')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--save_every', default=1, type=int, help='save the model every x epochs')
parser.add_argument('--pretrain', default=None, type=str, help='path of pre-trained model')
parser.add_argument('--use_modelarts', type=int, default=0,
                    help='1 for True, 0 for False; when set to 1, the dataset is loaded from OBS with moxing')
parser.add_argument('--train_url', type=str, default='train_url/',
                    help='needed by ModelArts, but not used here because the name is ambiguous')
parser.add_argument('--data_url', type=str, default='data_url/',
                    help='needed by ModelArts, but not used here because the name is ambiguous')
parser.add_argument('--output_path', type=str, default='./output/',
                    help='output path; when use_modelarts is set to 1, it will be cache/output/')
parser.add_argument('--outer_path', type=str, default='s3://output/',
                    help='OBS path to store outputs such as ckpt files')

parser.add_argument('--device_target', type=str, default='Ascend',
                    help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all ranks')
parser.add_argument('--ckpt_save_max', type=int, default=5,
                    help='Maximum number of checkpoint files that can be saved. Default: 5.')

set_seed(1)
args = parser.parse_args()
save_dir = os.path.join(args.output_path, 'sigma_' + str(args.sigma) \
                        + '_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

if not args.use_modelarts and not os.path.exists(save_dir):
    os.makedirs(save_dir)


def get_lr(steps_per_epoch, max_epoch, init_lr):
    lr_each_step = []
    while max_epoch > 0:
        tem = min(30, max_epoch)
        for _ in range(steps_per_epoch*tem):
            lr_each_step.append(init_lr)
        max_epoch -= tem
        init_lr /= 10
    return lr_each_step


device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                    device_target=args.device_target, save_graphs=False)


def train():

    if args.is_distributed:
        assert args.device_target == "Ascend"
        init()
        context.set_context(device_id=device_id)
        args.rank = get_rank()
        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
    else:
        if args.device_target == "Ascend":
            context.set_context(device_id=device_id)

    # select whether only the master rank or all ranks save ckpt, compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    args.logger = get_logger(save_dir, "BRDNet", args.rank)
    args.logger.save_args(args)

    if args.use_modelarts:
        import moxing as mox
        args.logger.info("copying train data from obs to cache....")
        mox.file.copy_parallel(args.train_data, 'cache/dataset')  # args.train_data must end with '/'
        args.logger.info("copying train data finished....")
        args.train_data = 'cache/dataset/'  # args.train_data must end with '/'

    dataset, batch_num = create_BRDNetDataset(args.train_data, args.sigma, \
                                              args.channel, args.batch_size, args.group_size, args.rank, shuffle=True)

    args.steps_per_epoch = int(batch_num / args.batch_size / args.group_size)

    model = BRDNet(args.channel)

    if args.pretrain:
        if args.use_modelarts:
            import moxing as mox
            args.logger.info("copying pretrain model from obs to cache....")
            mox.file.copy_parallel(args.pretrain, 'cache/pretrain')
            args.logger.info("copying pretrain model finished....")
            args.pretrain = 'cache/pretrain/'+args.pretrain.split('/')[-1]

        args.logger.info('loading pre-trained model {} into network'.format(args.pretrain))
        load_param_into_net(model, load_checkpoint(args.pretrain))
        args.logger.info('loaded pre-trained model {} into network'.format(args.pretrain))

    model = BRDWithLossCell(model)
    model.set_train()

    lr_list = get_lr(args.steps_per_epoch, args.epoch, args.lr)
    optimizer = nn.Adam(params=model.trainable_params(), learning_rate=Tensor(lr_list, mindspore.float32))
    model = TrainingWrapper(model, optimizer)

    model = Model(model)

    # define callbacks
    if args.rank == 0:
        time_cb = TimeMonitor(data_size=batch_num)
        loss_cb = LossMonitor(per_print_times=10)
        callbacks = [time_cb, loss_cb]
    else:
        callbacks = []
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every,
                                       keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='channel_'+str(args.channel)+'_sigma_'+str(args.sigma)+'_rank_'+str(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.epoch, dataset, callbacks=callbacks, dataset_sink_mode=True)

    args.logger.info("training finished....")
    if args.use_modelarts:
        args.logger.info("copying files from cache to obs....")
        mox.file.copy_parallel(save_dir, args.outer_path)
        args.logger.info("copying finished....")


if __name__ == '__main__':
    train()
    args.logger.info('All tasks finished!')
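The step schedule built by `get_lr` holds the initial learning rate for 30 epochs and then divides it by 10 for every further 30-epoch block, so the default run (50 epochs, lr=1e-3) uses 1e-3 for epochs 1-30 and 1e-4 for epochs 31-50. An illustrative standalone check that mirrors that logic (the steps_per_epoch value is a placeholder):

```python
# Illustrative only: mirrors the logic of get_lr() in train.py for the default settings.
def step_lr(steps_per_epoch, max_epoch, init_lr):
    lrs = []
    while max_epoch > 0:
        block = min(30, max_epoch)                    # hold the lr for up to 30 epochs
        lrs.extend([init_lr] * (steps_per_epoch * block))
        max_epoch -= block
        init_lr /= 10                                  # decay by 10x for the next block
    return lrs

schedule = step_lr(steps_per_epoch=100, max_epoch=50, init_lr=1e-3)
print(schedule[0], schedule[30 * 100 - 1], schedule[30 * 100])   # 0.001 0.001 0.0001
```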