!17158 Add image denoising network BRDNet to the master branch

Merge pull request !17158 from 谭华林/master
This commit is contained in:
i-robot 2021-09-15 08:13:47 +00:00 committed by Gitee
commit 8720c66af8
19 changed files with 2138 additions and 0 deletions

@@ -0,0 +1,532 @@
# BRDNet
<!-- TOC -->
- [BRDNet](#brdnet)
- [BRDNet Description](#brdnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Scripts and Sample Code](#scripts-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Ascend 310 Inference](#ascend-310-inference)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
## BRDNet Description
BRDNet was published in 2020 in the artificial intelligence journal Neural Networks and was listed among the journal's most downloaded papers of 2019/2020. The paper proposes a scheme for handling non-uniformly distributed data under constrained hardware resources, and is the first to use a dual-branch network that extracts complementary features for image denoising; it effectively removes both synthetic and real noise.
[Paper](https://www.sciencedirect.com/science/article/pii/S0893608019302394): Tian C, Xu Y, Zuo W. Image denoising using deep CNN with batch renormalization[J]. Neural Networks, 2020, 121: 461-473.
## Model Architecture
BRDNet consists of an upper and a lower branch. The upper branch contains only residual learning and BRN (batch renormalization), while the lower branch contains BRN, residual learning, and dilated convolutions.
The upper branch contains two types of layers, Conv+BRN+ReLU and Conv, and is 17 layers deep: the first 16 layers are Conv+BRN+ReLU and the last is Conv. Feature maps have 64 channels and the convolution kernel size is 3.
The lower branch also contains 17 layers: layers 1, 9, and 16 are Conv+BRN+ReLU, layers 2-8 and 10-15 are dilated convolutions, and the last layer is Conv. The kernel size is 3 and the channel count is 64.
The outputs of the two branches are combined by `Concat` and passed through a Conv layer to predict the noise; subtracting this noise from the original input yields the clean, noise-free image. In total BRDNet contains 18 layers, which is relatively shallow and avoids vanishing and exploding gradients.
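The combine-and-subtract step above can be sketched with NumPy. The two 17-layer branches are replaced here by stand-in noise estimators (faked from the known noise), so this only illustrates the dataflow (channel-wise `Concat`, fusion, residual subtraction), not the trained network in src/models.py:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 15

# A toy 50x50 color patch in [0, 1], the size used during training.
clean = rng.random((50, 50, 3)).astype(np.float32)
noisy = clean + rng.normal(0, sigma / 255.0, clean.shape).astype(np.float32)

# Stand-ins for the two 17-layer branches: each predicts a noise map of the
# same shape as the input (here derived from the known noise for illustration).
noise_upper = (noisy - clean) * 0.6
noise_lower = (noisy - clean) * 0.4

# BRDNet concatenates the branch outputs along the channel axis and fuses
# them with a final Conv; a 1x1 conv with weights [1, 1] reduces to a sum.
fused = np.concatenate([noise_upper, noise_lower], axis=-1)
noise = fused[..., :3] + fused[..., 3:]

# Residual learning: subtract the estimated noise from the noisy input.
denoised = noisy - noise
print(np.allclose(denoised, clean, atol=1e-6))
```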
## Dataset
Training dataset: [color noisy](<https://pan.baidu.com/s/1cx3ymsWLIT-YIiJRBza24Q>)
For Gaussian denoising, the dataset consists of 3,859 images from the Waterloo Exploration Database, preprocessed into 50x50 patches, 1,166,393 in total.
Test datasets: CBSD68, Kodak24, and McMaster
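A minimal sketch of how such patches can be cut from a source image. The 50x50 patch size and stride of 40 are inferred from the dataset name `waterloo5050step40colorimage` and are an assumption, not a documented spec:

```python
import numpy as np

def extract_patches(img, patch=50, step=40):
    """Slide a patch x patch window over an HWC image with the given stride.
    Patch size 50 and stride 40 mirror the dataset name
    'waterloo5050step40colorimage' (an assumption, not a documented spec)."""
    h, w = img.shape[:2]
    return [img[i:i + patch, j:j + patch]
            for i in range(0, h - patch + 1, step)
            for j in range(0, w - patch + 1, step)]

img = np.zeros((180, 180, 3), dtype=np.float32)
patches = extract_patches(img)
print(len(patches), patches[0].shape)   # 4 x 4 window positions
```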
## Environment Requirements
- Hardware (Ascend/ModelArts)
    - Prepare an Ascend processor or a ModelArts environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.html)
## Quick Start
After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
```shell
# Run single-device training via python; note that train_data must end with "/"
python train.py \
--train_data=xxx/dataset/waterloo5050step40colorimage/ \
--sigma=15 \
--channel=3 \
--batch_size=32 \
--lr=0.001 \
--use_modelarts=0 \
--output_path=./output/ \
--is_distributed=0 \
--epoch=50 > log.txt 2>&1 &
# Launch single-device training via sh. (Paths such as train_data may be given in any form; they are converted internally to absolute paths ending with "/".)
sh ./scripts/run_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr]
# Distributed training on Ascend
sh run_distribute_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]
# Run the evaluation script via python; note that test_dir must end with "/";
# pretrain_path is the directory containing the ckpt; for ModelArts compatibility it is split into a "path" and a "file name"
python eval.py \
--test_dir=xxx/Test/Kodak24/ \
--sigma=15 \
--channel=3 \
--pretrain_path=xxx/ \
--ckpt_name=channel_3_sigma_15_rank_0-50_227800.ckpt \
--use_modelarts=0 \
--output_path=./output/ \
--is_distributed=0 > log.txt 2>&1 &
# Launch evaluation via sh. (Paths such as test_dir may be given in any form; they are converted internally to absolute paths ending with "/".)
sh run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name]
```
For distributed training on Ascend, a [RANK_TABLE_FILE](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) must be generated first.
## Script Description
### Scripts and Sample Code
```shell
├── model_zoo
    ├── README.md                     // descriptions of all models
    ├── brdnet
        ├── README.md                 // description of brdnet
        ├── ascend310_infer           // Ascend 310 inference code (C++)
        │   ├──inc
        │   │   ├──utils.h            // utility header
        │   ├──src
        │   │   ├──main.cc            // Ascend 310 inference code
        │   │   ├──utils.cc           // utilities
        │   ├──build.sh               // build script
        │   ├──CMakeLists.txt         // build configuration
        ├── scripts
        │   ├──run_distribute_train.sh // 8-device Ascend training script
        │   ├──run_eval.sh            // evaluation launch script
        │   ├──run_train.sh           // training launch script
        │   ├──run_infer_310.sh       // Ascend 310 inference launch script
        ├── src
        │   ├──dataset.py             // dataset processing
        │   ├──distributed_sampler.py // dataset sharding for 8-device parallel training
        │   ├──logger.py              // logging utilities
        │   ├──models.py              // model definition
        ├── export.py                 // exports the checkpoint to MINDIR and other formats
        ├── train.py                  // training script
        ├── eval.py                   // evaluation script
        ├── cal_psnr.py               // computes the final PSNR for Ascend 310 inference
        ├── preprocess.py             // adds noise to test images for Ascend 310 inference
```
### Script Parameters
```shell
Major parameters in train.py:
--batch_size: batch size
--train_data: path of the training dataset (must end with "/")
--sigma: Gaussian noise level
--channel: 3 for color images, 1 for grayscale
--epoch: number of training epochs
--lr: initial learning rate
--save_every: checkpoint saving interval (save once every N epochs)
--pretrain: pretrained checkpoint to resume training from
--use_modelarts: whether to run on ModelArts (1 for True, 0 for False; when set to 1, moxing copies data from OBS)
--train_url: required by ModelArts but unused in the code because the name is ambiguous
--data_url: required by ModelArts but unused in the code because the name is ambiguous
--output_path: output directory for logs and other files
--outer_path: output directory on OBS (only effective when running on ModelArts)
--device_target: target device (default "Ascend")
--is_distributed: whether to run on multiple devices
--rank: local rank in distributed training. Default: 0
--group_size: world size in distributed training. Default: 1
--is_save_on_master: whether to save checkpoints only on device 0
--ckpt_save_max: maximum number of checkpoints to keep

Major parameters in eval.py:
--test_dir: path of the test dataset (must end with "/")
--sigma: Gaussian noise level
--channel: 3 for color images, 1 for grayscale
--pretrain_path: directory containing the checkpoint
--ckpt_name: checkpoint file name
--use_modelarts: whether to run on ModelArts (1 for True, 0 for False; when set to 1, moxing copies data from OBS)
--train_url: required by ModelArts but unused in the code because the name is ambiguous
--data_url: required by ModelArts but unused in the code because the name is ambiguous
--output_path: output directory for logs and other files
--outer_path: output directory on OBS (only effective when running on ModelArts)
--device_target: target device (default "Ascend")

Major parameters in export.py:
--batch_size: batch size
--channel: 3 for color images, 1 for grayscale
--image_height: image height
--image_width: image width
--ckpt_file: checkpoint file path
--file_name: name of the exported file
--file_format: export format, choices=["AIR", "ONNX", "MINDIR"]
--device_target: target device (default "Ascend")
--device_id: device id

Major parameters in preprocess.py:
--out_dir: directory to save the noisy images
--image_path: path of the test images (must end with "/")
--channel: 3 for color images, 1 for grayscale
--sigma: noise level

Major parameters in cal_psnr.py:
--image_path: path of the test images (must end with "/")
--output_path: directory of the inference results, default xx/scripts/result_Files/
--image_width: image width
--image_height: image height
--channel: 3 for color images, 1 for grayscale
```
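For reference, adding Gaussian noise at a given sigma (as preprocess.py does before Ascend 310 inference) can be sketched as below; `add_gaussian_noise` is a hypothetical helper, not the script's actual function:

```python
import numpy as np

def add_gaussian_noise(img, sigma, seed=0):
    """Add white Gaussian noise at the given sigma (on the 0-255 scale)
    to an image scaled to [0, 1]. A hedged sketch of what preprocess.py
    does, not its actual implementation."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, sigma / 255.0, img.shape).astype(img.dtype)
    return img + noise

clean = np.full((4, 4, 3), 0.5, dtype=np.float32)
noisy = add_gaussian_noise(clean, sigma=15)
print(noisy.shape)
```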
### Training Process
#### Training
- Running on Ascend
```shell
# Run single-device training via python; note that train_data must end with "/"
python train.py \
--train_data=xxx/dataset/waterloo5050step40colorimage/ \
--sigma=15 \
--channel=3 \
--batch_size=32 \
--lr=0.001 \
--use_modelarts=0 \
--output_path=./output/ \
--is_distributed=0 \
--epoch=50 > log.txt 2>&1 &
# Launch single-device training via sh. (Paths such as train_data may be given in any form; they are converted internally to absolute paths ending with "/".)
sh ./scripts/run_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr]
# Both commands above run in the background; logs are written to log.txt, which you can inspect for training details
# Distributed training on Ascend (for 2-, 4-, or 8-device setups, modify run_distribute_train.sh yourself; defaults to 8 devices)
sh run_distribute_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]
```
Note: on the first run, the process may stay at the following output for quite a while, because the log is printed only after an epoch completes. Please be patient.
On a single device, the first epoch takes roughly 20-30 minutes.
```python
2021-05-16 20:12:17,888:INFO:Args:
2021-05-16 20:12:17,888:INFO:--> batch_size: 32
2021-05-16 20:12:17,888:INFO:--> train_data: ../dataset/waterloo5050step40colorimage/
2021-05-16 20:12:17,889:INFO:--> sigma: 15
2021-05-16 20:12:17,889:INFO:--> channel: 3
2021-05-16 20:12:17,889:INFO:--> epoch: 50
2021-05-16 20:12:17,889:INFO:--> lr: 0.001
2021-05-16 20:12:17,889:INFO:--> save_every: 1
2021-05-16 20:12:17,889:INFO:--> pretrain: None
2021-05-16 20:12:17,889:INFO:--> use_modelarts: False
2021-05-16 20:12:17,889:INFO:--> train_url: train_url/
2021-05-16 20:12:17,889:INFO:--> data_url: data_url/
2021-05-16 20:12:17,889:INFO:--> output_path: ./output/
2021-05-16 20:12:17,889:INFO:--> outer_path: s3://output/
2021-05-16 20:12:17,889:INFO:--> device_target: Ascend
2021-05-16 20:12:17,890:INFO:--> is_distributed: 0
2021-05-16 20:12:17,890:INFO:--> rank: 0
2021-05-16 20:12:17,890:INFO:--> group_size: 1
2021-05-16 20:12:17,890:INFO:--> is_save_on_master: 1
2021-05-16 20:12:17,890:INFO:--> ckpt_save_max: 5
2021-05-16 20:12:17,890:INFO:--> rank_save_ckpt_flag: 1
2021-05-16 20:12:17,890:INFO:--> logger: <LOGGER BRDNet (NOTSET)>
2021-05-16 20:12:17,890:INFO:
```
After training finishes, you can find the saved checkpoints in the directory specified by --output_path. Part of the loss convergence during training looks like this:
```python
# grep "epoch time:" log.txt
epoch time: 1197471.061 ms, per step time: 32.853 ms
epoch time: 1136826.065 ms, per step time: 31.189 ms
epoch time: 1136840.334 ms, per step time: 31.190 ms
epoch time: 1136837.709 ms, per step time: 31.190 ms
epoch time: 1137081.757 ms, per step time: 31.197 ms
epoch time: 1136830.581 ms, per step time: 31.190 ms
epoch time: 1136845.253 ms, per step time: 31.190 ms
epoch time: 1136881.960 ms, per step time: 31.191 ms
epoch time: 1136850.673 ms, per step time: 31.190 ms
epoch: 10 step: 36449, loss is 103.104095
epoch time: 1137098.407 ms, per step time: 31.197 ms
epoch time: 1136794.613 ms, per step time: 31.189 ms
epoch time: 1136742.922 ms, per step time: 31.187 ms
epoch time: 1136842.009 ms, per step time: 31.190 ms
epoch time: 1136792.705 ms, per step time: 31.189 ms
epoch time: 1137056.362 ms, per step time: 31.196 ms
epoch time: 1136863.373 ms, per step time: 31.191 ms
epoch time: 1136842.938 ms, per step time: 31.190 ms
epoch time: 1136839.011 ms, per step time: 31.190 ms
epoch time: 1136879.794 ms, per step time: 31.191 ms
epoch: 20 step: 36449, loss is 61.104546
epoch time: 1137035.395 ms, per step time: 31.195 ms
epoch time: 1136830.626 ms, per step time: 31.190 ms
epoch time: 1136862.117 ms, per step time: 31.190 ms
epoch time: 1136812.265 ms, per step time: 31.189 ms
epoch time: 1136821.096 ms, per step time: 31.189 ms
epoch time: 1137050.310 ms, per step time: 31.196 ms
epoch time: 1136815.292 ms, per step time: 31.189 ms
epoch time: 1136817.757 ms, per step time: 31.189 ms
epoch time: 1136876.477 ms, per step time: 31.191 ms
epoch time: 1136798.538 ms, per step time: 31.189 ms
epoch: 30 step: 36449, loss is 116.179596
epoch time: 1136972.930 ms, per step time: 31.194 ms
epoch time: 1136825.174 ms, per step time: 31.189 ms
epoch time: 1136798.900 ms, per step time: 31.189 ms
epoch time: 1136828.101 ms, per step time: 31.190 ms
epoch time: 1136862.983 ms, per step time: 31.191 ms
epoch time: 1136989.445 ms, per step time: 31.194 ms
epoch time: 1136688.820 ms, per step time: 31.186 ms
epoch time: 1136858.111 ms, per step time: 31.190 ms
epoch time: 1136822.853 ms, per step time: 31.189 ms
epoch time: 1136782.455 ms, per step time: 31.188 ms
epoch: 40 step: 36449, loss is 70.95368
epoch time: 1137042.689 ms, per step time: 31.195 ms
epoch time: 1136797.706 ms, per step time: 31.189 ms
epoch time: 1136817.007 ms, per step time: 31.189 ms
epoch time: 1136861.577 ms, per step time: 31.190 ms
epoch time: 1136698.149 ms, per step time: 31.186 ms
epoch time: 1137052.034 ms, per step time: 31.196 ms
epoch time: 1136809.339 ms, per step time: 31.189 ms
epoch time: 1136851.343 ms, per step time: 31.190 ms
epoch time: 1136761.354 ms, per step time: 31.188 ms
epoch time: 1136837.762 ms, per step time: 31.190 ms
epoch: 50 step: 36449, loss is 87.13184
epoch time: 1137022.554 ms, per step time: 31.195 ms
2021-05-19 14:24:52,695:INFO:training finished....
...
```
Loss convergence with 8-device parallel training:
```python
epoch time: 217708.130 ms, per step time: 47.785 ms
epoch time: 144899.598 ms, per step time: 31.804 ms
epoch time: 144736.054 ms, per step time: 31.768 ms
epoch time: 144737.085 ms, per step time: 31.768 ms
epoch time: 144738.102 ms, per step time: 31.769 ms
epoch: 5 step: 4556, loss is 106.67432
epoch time: 144905.830 ms, per step time: 31.805 ms
epoch time: 144736.539 ms, per step time: 31.768 ms
epoch time: 144734.210 ms, per step time: 31.768 ms
epoch time: 144734.415 ms, per step time: 31.768 ms
epoch time: 144736.405 ms, per step time: 31.768 ms
epoch: 10 step: 4556, loss is 94.092865
epoch time: 144921.081 ms, per step time: 31.809 ms
epoch time: 144735.718 ms, per step time: 31.768 ms
epoch time: 144737.036 ms, per step time: 31.768 ms
epoch time: 144737.733 ms, per step time: 31.769 ms
epoch time: 144738.251 ms, per step time: 31.769 ms
epoch: 15 step: 4556, loss is 99.18075
epoch time: 144921.945 ms, per step time: 31.809 ms
epoch time: 144734.948 ms, per step time: 31.768 ms
epoch time: 144735.662 ms, per step time: 31.768 ms
epoch time: 144733.871 ms, per step time: 31.768 ms
epoch time: 144734.722 ms, per step time: 31.768 ms
epoch: 20 step: 4556, loss is 92.54497
epoch time: 144907.430 ms, per step time: 31.806 ms
epoch time: 144735.713 ms, per step time: 31.768 ms
epoch time: 144733.781 ms, per step time: 31.768 ms
epoch time: 144736.005 ms, per step time: 31.768 ms
epoch time: 144734.331 ms, per step time: 31.768 ms
epoch: 25 step: 4556, loss is 90.98991
epoch time: 144911.420 ms, per step time: 31.807 ms
epoch time: 144734.535 ms, per step time: 31.768 ms
epoch time: 144734.851 ms, per step time: 31.768 ms
epoch time: 144736.346 ms, per step time: 31.768 ms
epoch time: 144734.939 ms, per step time: 31.768 ms
epoch: 30 step: 4556, loss is 114.33954
epoch time: 144915.434 ms, per step time: 31.808 ms
epoch time: 144737.336 ms, per step time: 31.769 ms
epoch time: 144733.943 ms, per step time: 31.768 ms
epoch time: 144734.587 ms, per step time: 31.768 ms
epoch time: 144735.043 ms, per step time: 31.768 ms
epoch: 35 step: 4556, loss is 97.21166
epoch time: 144912.719 ms, per step time: 31.807 ms
epoch time: 144734.795 ms, per step time: 31.768 ms
epoch time: 144733.824 ms, per step time: 31.768 ms
epoch time: 144735.946 ms, per step time: 31.768 ms
epoch time: 144734.930 ms, per step time: 31.768 ms
epoch: 40 step: 4556, loss is 82.41978
epoch time: 144901.017 ms, per step time: 31.804 ms
epoch time: 144735.060 ms, per step time: 31.768 ms
epoch time: 144733.657 ms, per step time: 31.768 ms
epoch time: 144732.592 ms, per step time: 31.767 ms
epoch time: 144731.292 ms, per step time: 31.767 ms
epoch: 45 step: 4556, loss is 77.92129
epoch time: 144909.250 ms, per step time: 31.806 ms
epoch time: 144732.944 ms, per step time: 31.768 ms
epoch time: 144733.161 ms, per step time: 31.768 ms
epoch time: 144732.912 ms, per step time: 31.768 ms
epoch time: 144733.709 ms, per step time: 31.768 ms
epoch: 50 step: 4556, loss is 85.499596
2021-05-19 02:44:44,219:INFO:training finished....
```
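The step counts in the logs can be sanity-checked from the dataset size and batch size, assuming drop-remainder batching and an even 8-way data shard:

```python
# Dataset and batching figures taken from this README.
patches = 1_166_393        # 50x50 training patches
batch_size = 32

steps_1p = patches // batch_size   # drop-remainder batching (assumed)
steps_8p = steps_1p // 8           # per-device steps with an 8-way shard

# 36449 and 4556, matching 'step: 36449' and 'step: 4556' in the logs;
# 4556 steps x 50 epochs = 227800, the step count in the checkpoint name.
print(steps_1p, steps_8p)
```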
### Evaluation Process
#### Evaluation
Before running the commands below, please make sure the checkpoint path used for evaluation is correct.
- Running on Ascend
```shell
# Launch evaluation via python; note that test_dir must end with "/"
python eval.py \
--test_dir=./Test/CBSD68/ \
--sigma=15 \
--channel=3 \
--pretrain_path=./ \
--ckpt_name=channel_3_sigma_15_rank_0-50_227800.ckpt \
--use_modelarts=0 \
--output_path=./output/ \
--is_distributed=0 > log.txt 2>&1 &
# Launch evaluation via sh. (Paths such as test_dir may be given in any form; they are converted internally to absolute paths ending with "/".)
sh run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name]
```
```python
2021-05-17 13:40:45,909:INFO:Start to test on ./Test/CBSD68/
2021-05-17 13:40:46,447:INFO:load test weights from channel_3_sigma_15_rank_0-50_227800.ckpt
2021-05-17 13:41:52,164:INFO:Before denoise: Average PSNR_b = 24.62, SSIM_b = 0.56;After denoise: Average PSNR = 34.05, SSIM = 0.94
2021-05-17 13:41:52,207:INFO:testing finished....
```
After evaluation, you can find the noisy images and the denoised images in the directory specified by --output_path; the file names encode the results. For example, 00001_sigma15_psnr24.62.bmp is the noisy image (PSNR = 24.62 after adding noise) and 00001_psnr31.18.bmp is the denoised image (PSNR = 31.18 after denoising).
In addition, the metrics.csv file in that directory records the result for every test image, as shown below: psnr_b is the PSNR before denoising, psnr the PSNR after denoising, and the ssim columns are analogous.
| | name | psnr_b | psnr | ssim_b | ssim |
| ---- | ------- | ----------- | ----------- | ----------- | ----------- |
| 0 | 1 | 24.61875916 | 31.17827606 | 0.716650724 | 0.910416007 |
| 1 | 2 | 24.61875916 | 35.12858963 | 0.457143694 | 0.995960176 |
| 2 | 3 | 24.61875916 | 34.90437698 | 0.465185702 | 0.935821533 |
| 3 | 4 | 24.61875916 | 35.59785461 | 0.49323535 | 0.941600204 |
| 4 | 5 | 24.61875916 | 32.9185257 | 0.605194688 | 0.958840668 |
| 5 | 6 | 24.61875916 | 37.29947662 | 0.368243992 | 0.962466478 |
| 6 | 7 | 24.61875916 | 33.59238052 | 0.622622728 | 0.930195987 |
| 7 | 8 | 24.61875916 | 31.76290894 | 0.680295587 | 0.918859363 |
| 8 | 9 | 24.61875916 | 34.13358688 | 0.55876708 | 0.939204693 |
| 9 | 10 | 24.61875916 | 34.49848557 | 0.503289104 | 0.928179622 |
| 10 | 11 | 24.61875916 | 34.38597107 | 0.656857133 | 0.961226702 |
| 11 | 12 | 24.61875916 | 32.75747299 | 0.627940595 | 0.910765707 |
| 12 | 13 | 24.61875916 | 34.52487564 | 0.54259634 | 0.936489582 |
| 13 | 14 | 24.61875916 | 35.40441132 | 0.44824928 | 0.93462956 |
| 14 | 15 | 24.61875916 | 32.72385788 | 0.61768961 | 0.91652298 |
| 15 | 16 | 24.61875916 | 33.59120178 | 0.703662276 | 0.948698342 |
| 16 | 17 | 24.61875916 | 36.85597229 | 0.365240872 | 0.940135658 |
| 17 | 18 | 24.61875916 | 37.23021317 | 0.366332233 | 0.953653395 |
| 18 | 19 | 24.61875916 | 33.49061584 | 0.546713233 | 0.928890586 |
| 19 | 20 | 24.61875916 | 33.98015213 | 0.463814735 | 0.938398063 |
| 20 | 21 | 24.61875916 | 32.15977859 | 0.714740098 | 0.945747674 |
| 21 | 22 | 24.61875916 | 32.39984512 | 0.716880679 | 0.930429876 |
| 22 | 23 | 24.61875916 | 34.22258759 | 0.569748521 | 0.945626318 |
| 23 | 24 | 24.61875916 | 33.974823 | 0.603115499 | 0.941333234 |
| 24 | 25 | 24.61875916 | 34.87198639 | 0.486003697 | 0.966141582 |
| 25 | 26 | 24.61875916 | 33.2747879 | 0.593207896 | 0.917522907 |
| 26 | 27 | 24.61875916 | 34.67901611 | 0.504613101 | 0.921615481 |
| 27 | 28 | 24.61875916 | 37.70562363 | 0.331322074 | 0.977765024 |
| 28 | 29 | 24.61875916 | 31.08887672 | 0.759773433 | 0.958483219 |
| 29 | 30 | 24.61875916 | 34.48878479 | 0.502000451 | 0.915705442 |
| 30 | 31 | 24.61875916 | 30.5480938 | 0.836367846 | 0.949165165 |
| 31 | 32 | 24.61875916 | 32.08041382 | 0.745214283 | 0.941719413 |
| 32 | 33 | 24.61875916 | 33.65553284 | 0.556162357 | 0.963605523 |
| 33 | 34 | 24.61875916 | 36.87154388 | 0.384932011 | 0.93150568 |
| 34 | 35 | 24.61875916 | 33.03263474 | 0.586027861 | 0.924151421 |
| 35 | 36 | 24.61875916 | 31.80633736 | 0.572878599 | 0.84426564 |
| 36 | 37 | 24.61875916 | 33.26797485 | 0.526310742 | 0.938789487 |
| 37 | 38 | 24.61875916 | 33.71062469 | 0.554955184 | 0.914420724 |
| 38 | 39 | 24.61875916 | 37.3455925 | 0.461908668 | 0.956513464 |
| 39 | 40 | 24.61875916 | 33.92232895 | 0.554454744 | 0.89727515 |
| 40 | 41 | 24.61875916 | 33.05244827 | 0.590977669 | 0.931611121 |
| 41 | 42 | 24.61875916 | 34.60203552 | 0.492371827 | 0.927084684 |
| 42 | 43 | 24.61875916 | 35.20042419 | 0.535991669 | 0.949365258 |
| 43 | 44 | 24.61875916 | 33.47367096 | 0.614959836 | 0.954348624 |
| 44 | 45 | 24.61875916 | 37.65309143 | 0.363631308 | 0.944297135 |
| 45 | 46 | 24.61875916 | 31.95152092 | 0.709732175 | 0.924522877 |
| 46 | 47 | 24.61875916 | 31.9910202 | 0.70427531 | 0.932488263 |
| 47 | 48 | 24.61875916 | 34.96608353 | 0.585813344 | 0.969006479 |
| 48 | 49 | 24.61875916 | 35.39241409 | 0.388898522 | 0.923762918 |
| 49 | 50 | 24.61875916 | 32.11050415 | 0.653521299 | 0.938310325 |
| 50 | 51 | 24.61875916 | 33.54981995 | 0.594990134 | 0.927819192 |
| 51 | 52 | 24.61875916 | 35.79096603 | 0.371685684 | 0.922166049 |
| 52 | 53 | 24.61875916 | 35.10015869 | 0.410564244 | 0.895557165 |
| 53 | 54 | 24.61875916 | 34.12319565 | 0.591762364 | 0.925524533 |
| 54 | 55 | 24.61875916 | 32.79537964 | 0.653338313 | 0.92444253 |
| 55 | 56 | 24.61875916 | 29.90909004 | 0.826190114 | 0.943322361 |
| 56 | 57 | 24.61875916 | 33.23035812 | 0.527200282 | 0.943572938 |
| 57 | 58 | 24.61875916 | 34.56663132 | 0.409658968 | 0.898451686 |
| 58 | 59 | 24.61875916 | 34.43690109 | 0.454208463 | 0.904649734 |
| 59 | 60 | 24.61875916 | 35.0402565 | 0.409306735 | 0.902388573 |
| 60 | 61 | 24.61875916 | 34.91940308 | 0.443635762 | 0.911728501 |
| 61 | 62 | 24.61875916 | 38.25325394 | 0.42568323 | 0.965163887 |
| 62 | 63 | 24.61875916 | 32.07671356 | 0.727443576 | 0.94306612 |
| 63 | 64 | 24.61875916 | 31.72690964 | 0.671929657 | 0.902075231 |
| 64 | 65 | 24.61875916 | 33.47768402 | 0.533677042 | 0.922906399 |
| 65 | 66 | 24.61875916 | 42.14694977 | 0.266868591 | 0.991976082 |
| 66 | 67 | 24.61875916 | 30.81770706 | 0.84768647 | 0.957114518 |
| 67 | 68 | 24.61875916 | 32.24455261 | 0.623004258 | 0.97843051 |
| 68 | Average | 24.61875916 | 34.05390495 | 0.555872787 | 0.935704286 |
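The per-image records in metrics.csv can be post-processed with the standard csv module; the two sample rows below are copied from the table above:

```python
import csv
import io

# Two sample rows in the metrics.csv layout shown above.
sample = """,name,psnr_b,psnr,ssim_b,ssim
0,1,24.61875916,31.17827606,0.716650724,0.910416007
1,2,24.61875916,35.12858963,0.457143694,0.995960176
"""

rows = list(csv.DictReader(io.StringIO(sample)))
avg_psnr = sum(float(r["psnr"]) for r in rows) / len(rows)
print(round(avg_psnr, 2))
```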
### Ascend 310 Inference
- Running on Ascend 310
```shell
# Launch inference via sh
sh run_infer_310.sh [model_path] [data_path] [noise_image_path] [sigma] [channel] [device_id]
# The command above runs the whole inference pipeline and produces three log files:
# preprocess.log, infer.log, and psnr.log.
# To run the steps separately, follow the flow in run_infer_310.sh to build the code, preprocess
# the images, run inference, and compute the PSNR; check the parameters each step requires.
```
## Model Description
### Performance
#### Evaluation Performance
BRDNet on "waterloo5050step40colorimage"
| Parameters | BRDNet |
| -------------------------- | ------------------------------------------------------------------------------ |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB                             |
| uploaded Date | 5/20/2021 (month/day/year) |
| MindSpore Version | master |
| Dataset | waterloo5050step40colorimage |
| Training Parameters | epoch=50, batch_size=32, lr=0.001 |
| Optimizer | Adam |
| Loss Function | MSELoss(reduction='sum') |
| outputs | denoised images |
| Loss | 80.839773 |
| Speed                      | 8p: about 7000-7400 FPS                                                        |
| Total time                 | 8p: about 2 h 14 min                                                           |
| Checkpoint for Fine tuning | 8p: 13.68MB , 1p: 19.76MB (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/brdnet |
## Description of Random Situation
A random seed is set in train.py.
## ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

@@ -0,0 +1,154 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include <random>
#include <ctime>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/minddata/dataset/include/vision_ascend.h"
#include "include/minddata/dataset/include/execute.h"
#include "include/minddata/dataset/include/transforms.h"
#include "include/minddata/dataset/include/constants.h"
#include "include/minddata/dataset/include/vision.h"
#include "inc/utils.h"
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Context;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::Graph;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::DataType;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::transforms::TypeCast;
DEFINE_string(model_path, "", "model path");
DEFINE_string(dataset_path, "", "dataset path");
DEFINE_int32(input_width, 500, "input width");
DEFINE_int32(input_height, 500, "input height");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_model_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310_info);
Graph graph;
Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
if (ret != kSuccess) {
std::cout << "Load model failed." << std::endl;
return 1;
}
Model model;
ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> modelInputs = model.GetInputs();
auto all_files = GetAllFiles(FLAGS_dataset_path);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = all_files.size();
// Define transform
for (size_t i = 0; i < size; ++i) {
struct timeval start;
struct timeval end;
double startTime_ms;
double endTime_ms;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
auto image = ReadFileToTensor(all_files[i]);
inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
image.Data().get(), image.DataSize());
gettimeofday(&start, NULL);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, NULL);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
return 1;
}
startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
WriteResult(all_files[i], outputs);
}
double average = 0.0;
int infer_cnt = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
infer_cnt++;
}
average = average / infer_cnt;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << infer_cnt << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
file_stream << timeCost.str();
file_stream.close();
costTime_map.clear();
return 0;
}

@@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " does not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << " open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " does not exist." << std::endl;
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

@@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import time
import glob
import math
import numpy as np
import PIL.Image as Image
## Params
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', default='./Test/Kodak24/'
, type=str, help='directory of test dataset')
parser.add_argument('--output_path', default='', type=str,
help='path of the predict files that generated by the model')
parser.add_argument('--image_width', default=500, type=int, help='image_width')
parser.add_argument('--image_height', default=500, type=int, help='image_height')
parser.add_argument('--channel', default=3, type=int
, help='image channel, 3 for color, 1 for gray')
args = parser.parse_args()
def calculate_psnr(image1, image2):
    """PSNR between two images scaled to [0, 1]."""
    image1 = np.float64(image1)
    image2 = np.float64(image2)
    diff = (image1 - image2).flatten('C')
    rmse = math.sqrt(np.mean(diff**2.))
    if rmse == 0:
        return float('inf')  # identical images
    return 20*math.log10(1.0/rmse)
def cal_psnr(output_path, image_path):
    file_list = glob.glob(image_path+'*') # image_path must end with '/'
psnr = [] #after denoise
start_time = time.time()
for file in file_list:
filename = file.split('/')[-1].split('.')[0] # get the name of image file
# read image
if args.channel == 3:
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
else:
img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)
result_file = os.path.join(output_path, filename+"_noise_0.bin")
y_predict = np.fromfile(result_file, dtype=np.float32)
y_predict = y_predict.reshape(args.channel, args.image_height, args.image_width)
img_out = y_predict.transpose((1, 2, 0))#HWC
img_out = np.clip(img_out, 0, 1)
psnr_denoised = calculate_psnr(img_clean, img_out)
psnr.append(psnr_denoised)
print(filename, ": psnr_denoised: ", " ", psnr_denoised)
psnr_avg = sum(psnr)/len(psnr)
print("Average PSNR:", psnr_avg)
print("Testing finished....")
time_used = time.time() - start_time
print("Time cost:"+str(time_used)+" seconds!")
if __name__ == '__main__':
cal_psnr(args.output_path, args.image_path)
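The PSNR computation used above can be sanity-checked on a toy image pair. This is a minimal standalone sketch (the 4x4 shape and the constant 0.1 error are illustrative), assuming images scaled to [0, 1] as in `cal_psnr.py`:

```python
import math
import numpy as np

def psnr(image1, image2):
    # same formula as calculate_psnr above, for images in [0, 1]
    diff = np.float64(image1) - np.float64(image2)
    rmse = math.sqrt(np.mean(diff ** 2))
    return 20 * math.log10(1.0 / rmse)

clean = np.zeros((4, 4), dtype=np.float64)
noisy = clean + 0.1           # constant error of 0.1 -> RMSE = 0.1
print(psnr(clean, noisy))     # 20 * log10(1 / 0.1) = 20.0 dB
```

A smaller error yields a higher PSNR, which is the direction the evaluation scripts rely on.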


@@ -0,0 +1,193 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import datetime
import argparse
import os
import time
import glob
import pandas as pd
import numpy as np
import PIL.Image as Image
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint, load_param_into_net
from src.logger import get_logger
from src.models import BRDNet
## Params
parser = argparse.ArgumentParser()
parser.add_argument('--test_dir', default='./Test/Kodak24/'
, type=str, help='directory of test dataset')
parser.add_argument('--sigma', default=15, type=int, help='noise level')
parser.add_argument('--channel', default=3, type=int
, help='image channel, 3 for color, 1 for gray')
parser.add_argument('--pretrain_path', default=None, type=str, help='path of pre-trained model')
parser.add_argument('--ckpt_name', default=None, type=str, help='ckpt_name')
parser.add_argument('--use_modelarts', type=int, default=0
                    , help='1 for True, 0 for False; when True, load the dataset from OBS with moxing')
parser.add_argument('--train_url', type=str, default='train_url/'
                    , help='needed by ModelArts, but unused here because the name is ambiguous')
parser.add_argument('--data_url', type=str, default='data_url/'
                    , help='needed by ModelArts, but unused here because the name is ambiguous')
parser.add_argument('--output_path', type=str, default='./output/'
                    , help='output path; when use_modelarts is True, it becomes cache/output/')
parser.add_argument('--outer_path', type=str, default='s3://output/'
                    , help='OBS path to store e.g. ckpt files')
parser.add_argument('--device_target', type=str, default='Ascend',
help='device where the code will be implemented. (Default: Ascend)')
set_seed(1)
args = parser.parse_args()
save_dir = os.path.join(args.output_path, 'sigma_' + str(args.sigma) + \
'_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if not args.use_modelarts and not os.path.exists(save_dir):
os.makedirs(save_dir)
def test(model_path):
args.logger.info('Start to test on {}'.format(args.test_dir))
    out_dir = os.path.join(save_dir, args.test_dir.split('/')[-2]) # args.test_dir must end with '/'
if not args.use_modelarts and not os.path.exists(out_dir):
os.makedirs(out_dir)
model = BRDNet(args.channel)
args.logger.info('load test weights from '+str(model_path))
load_param_into_net(model, load_checkpoint(model_path))
name = []
psnr = [] #after denoise
ssim = [] #after denoise
psnr_b = [] #before denoise
ssim_b = [] #before denoise
if args.use_modelarts:
args.logger.info("copying test dataset from obs to cache....")
mox.file.copy_parallel(args.test_dir, 'cache/test')
args.logger.info("copying test dataset finished....")
args.test_dir = 'cache/test/'
    file_list = glob.glob(args.test_dir+'*') # args.test_dir must end with '/'
model.set_train(False)
cast = P.Cast()
transpose = P.Transpose()
expand_dims = P.ExpandDims()
compare_psnr = nn.PSNR()
compare_ssim = nn.SSIM()
args.logger.info("start testing....")
start_time = time.time()
for file in file_list:
suffix = file.split('.')[-1]
# read image
if args.channel == 3:
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
else:
img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)
np.random.seed(0) #obtain the same random data when it is in the test phase
img_test = img_clean + np.random.normal(0, args.sigma/255.0, img_clean.shape)
img_clean = Tensor(img_clean, mindspore.float32) #HWC
img_test = Tensor(img_test, mindspore.float32) #HWC
# predict
img_clean = expand_dims(transpose(img_clean, (2, 0, 1)), 0)#NCHW
img_test = expand_dims(transpose(img_test, (2, 0, 1)), 0)#NCHW
y_predict = model(img_test) #NCHW
# calculate numeric metrics
img_out = C.clip_by_value(y_predict, 0, 1)
psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test), compare_psnr(img_clean, img_out)
ssim_noise, ssim_denoised = compare_ssim(img_clean, img_test), compare_ssim(img_clean, img_out)
psnr.append(psnr_denoised.asnumpy()[0])
ssim.append(ssim_denoised.asnumpy()[0])
psnr_b.append(psnr_noise.asnumpy()[0])
ssim_b.append(ssim_noise.asnumpy()[0])
# save images
filename = file.split('/')[-1].split('.')[0] # get the name of image file
name.append(filename)
        if not args.use_modelarts:
            # Image.save first checks whether a file with the same name
            # already exists, which is not allowed on ModelArts
img_test = cast(img_test*255, mindspore.uint8).asnumpy()
img_test = img_test.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
img_test = Image.fromarray(img_test)
img_test.save(os.path.join(out_dir, filename+'_sigma'+'{}_psnr{:.2f}.'\
.format(args.sigma, psnr_noise.asnumpy()[0])+str(suffix)))
img_out = cast(img_out*255, mindspore.uint8).asnumpy()
img_out = img_out.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
img_out = Image.fromarray(img_out)
img_out.save(os.path.join(out_dir, filename+'_psnr{:.2f}.'.format(psnr_denoised.asnumpy()[0])+str(suffix)))
psnr_avg = sum(psnr)/len(psnr)
ssim_avg = sum(ssim)/len(ssim)
psnr_avg_b = sum(psnr_b)/len(psnr_b)
ssim_avg_b = sum(ssim_b)/len(ssim_b)
name.append('Average')
psnr.append(psnr_avg)
ssim.append(ssim_avg)
psnr_b.append(psnr_avg_b)
ssim_b.append(ssim_avg_b)
args.logger.info('Before denoise: Average PSNR_b = {0:.2f}, \
SSIM_b = {1:.2f};After denoise: Average PSNR = {2:.2f}, SSIM = {3:.2f}'\
.format(psnr_avg_b, ssim_avg_b, psnr_avg, ssim_avg))
args.logger.info("testing finished....")
time_used = time.time() - start_time
args.logger.info("time cost:"+str(time_used)+" seconds!")
if not args.use_modelarts:
pd.DataFrame({'name': np.array(name), 'psnr_b': np.array(psnr_b), \
'psnr': np.array(psnr), 'ssim_b': np.array(ssim_b), \
'ssim': np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True)
if __name__ == '__main__':
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, device_id=device_id, save_graphs=False)
args.logger = get_logger(save_dir, "BRDNet", 0)
args.logger.save_args(args)
if args.use_modelarts:
import moxing as mox
args.logger.info("copying test weights from obs to cache....")
mox.file.copy_parallel(args.pretrain_path, 'cache/weight')
args.logger.info("copying test weights finished....")
args.pretrain_path = 'cache/weight/'
test(os.path.join(args.pretrain_path, args.ckpt_name))
if args.use_modelarts:
args.logger.info("copying files from cache to obs....")
mox.file.copy_parallel(save_dir, args.outer_path)
args.logger.info("copying finished....")


@@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.models import BRDNet
## Params
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int, help='batch size')
parser.add_argument('--channel', default=3, type=int
, help='image channel, 3 for color, 1 for gray')
parser.add_argument("--image_height", type=int, default=50, help="Image height.")
parser.add_argument("--image_width", type=int, default=50, help="Image width.")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="brdnet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument('--device_target', type=str, default='Ascend'
, help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
if __name__ == '__main__':
net = BRDNet(args_opt.channel)
param_dict = load_checkpoint(args_opt.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([args_opt.batch_size, args_opt.channel, \
args_opt.image_height, args_opt.image_width]), ms.float32)
export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)


@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import glob
import numpy as np
import PIL.Image as Image
parser = argparse.ArgumentParser()
parser.add_argument('--out_dir', type=str, required=True,
help='directory to store the image with noise')
parser.add_argument('--image_path', type=str, required=True,
help='directory of image to add noise')
parser.add_argument('--channel', type=int, default=3
, help='image channel, 3 for color, 1 for gray')
parser.add_argument('--sigma', type=int, default=15, help='level of noise')
args = parser.parse_args()
def add_noise(out_dir, image_path, channel, sigma):
    file_list = glob.glob(image_path+'*') # image_path must end with '/'
if not os.path.exists(out_dir):
os.makedirs(out_dir)
for file in file_list:
print("Adding noise to: ", file)
# read image
if channel == 3:
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
else:
img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)
np.random.seed(0) #obtain the same random data when it is in the test phase
img_test = img_clean + np.random.normal(0, sigma/255.0, img_clean.shape).astype(np.float32)#HWC
img_test = np.expand_dims(img_test.transpose((2, 0, 1)), 0)#NCHW
#img_test = np.clip(img_test, 0, 1)
filename = file.split('/')[-1].split('.')[0] # get the name of image file
img_test.tofile(os.path.join(out_dir, filename+'_noise.bin'))
if __name__ == "__main__":
add_noise(args.out_dir, args.image_path, args.channel, args.sigma)
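Seeding NumPy the way `preprocess.py` does makes the additive Gaussian noise reproducible across runs, so the `.bin` noise images fed to 310 inference match the noise used during evaluation. A minimal sketch of that property (the 50x50x3 shape and sigma=15 are illustrative values):

```python
import numpy as np

sigma = 15  # noise level, as in the scripts above

np.random.seed(0)
noise_a = np.random.normal(0, sigma / 255.0, (50, 50, 3)).astype(np.float32)

np.random.seed(0)  # reseeding reproduces exactly the same draw
noise_b = np.random.normal(0, sigma / 255.0, (50, 50, 3)).astype(np.float32)

print(np.array_equal(noise_a, noise_b))  # True
```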


@@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 8 ]; then
echo "Usage: sh run_distribute_train.sh [train_code_path] [train_data]" \
"[batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
get_real_path_name() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
train_code_path=$(get_real_path $1)
echo $train_code_path
if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
exit 1
fi
train_data=$(get_real_path $2)
echo $train_data
if [ ! -d $train_data ]
then
    echo "error: train_data=$train_data is not a directory."
exit 1
fi
rank_table_file_path=$(get_real_path_name $8)
echo "rank_table_file_path: "$rank_table_file_path
ulimit -c unlimited
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=$rank_table_file_path
export RANK_SIZE=8
export RANK_START_ID=0
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
if [ -d ${train_code_path}device${DEVICE_ID} ]; then
rm -rf ${train_code_path}device${DEVICE_ID}
fi
mkdir ${train_code_path}device${DEVICE_ID}
cd ${train_code_path}device${DEVICE_ID} || exit
nohup python ${train_code_path}train.py --is_distributed=1 \
--train_data=${train_data} \
--batch_size=$3 \
--sigma=$4 \
--channel=$5 \
--epoch=$6 \
--lr=$7 > ${train_code_path}device${DEVICE_ID}/log.txt 2>&1 &
done


@@ -0,0 +1,63 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 6 ]; then
echo "Usage: sh run_eval.sh [train_code_path] [test_dir] [sigma]" \
"[channel] [pretrain_path] [ckpt_name]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path
test_dir=$(get_real_path $2)
echo "test_dir: "$test_dir
pretrain_path=$(get_real_path $5)
echo "pretrain_path: "$pretrain_path
if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
exit 1
fi
if [ ! -d $test_dir ]
then
    echo "error: test_dir=$test_dir is not a directory."
exit 1
fi
if [ ! -d $pretrain_path ]
then
    echo "error: pretrain_path=$pretrain_path is not a directory."
exit 1
fi
python ${train_code_path}eval.py \
--test_dir=$test_dir \
--sigma=$3 \
--channel=$4 \
--pretrain_path=$pretrain_path \
--ckpt_name=$6


@@ -0,0 +1,131 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 6 ]; then
echo "Usage: sh run_infer_310.sh [model_path] [data_path]" \
"[noise_image_path] [sigma] [channel] [device_id]"
exit 1
fi
get_real_path_name() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
model=$(get_real_path_name $1)
data_path=$(get_real_path $2)
noise_image_path=$(get_real_path $3)
sigma=$4
channel=$5
device_id=$6
echo "model path: "$model
echo "dataset path: "$data_path
echo "noise image path: "$noise_image_path
echo "sigma: "$sigma
echo "channel: "$channel
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer/ || exit
if [ -f "Makefile" ]; then
make clean
fi
sh build.sh &> build.log
}
function preprocess()
{
cd ../ || exit
    echo "start preprocess"
    echo "waiting for preprocess to finish..."
    python3.7 preprocess.py --out_dir=$noise_image_path --image_path=$data_path --channel=$channel --sigma=$sigma > preprocess.log 2>&1
    echo "preprocess finished! you can see the log in preprocess.log!"
}
function infer()
{
cd ./scripts || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
    echo "start infer..."
    echo "waiting for infer to finish..."
../ascend310_infer/out/main --model_path=$model --dataset_path=$noise_image_path --device_id=$device_id > infer.log 2>&1
echo "infer finished! you can see the log in infer.log!"
}
function cal_psnr()
{
    echo "start calculating PSNR..."
    echo "waiting for calculation to finish..."
    python3.7 ../cal_psnr.py --image_path=$data_path --output_path=./result_Files --channel=$channel > psnr.log 2>&1
    echo "calculation finished! you can see the log in psnr.log"
}
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
preprocess
if [ $? -ne 0 ]; then
echo "execute preprocess failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
exit 1
fi
cal_psnr
if [ $? -ne 0 ]; then
echo "calculate psnr failed"
exit 1
fi


@@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 7 ]; then
echo "Usage: sh run_train.sh [train_code_path] [train_data]" \
"[batch_size] [sigma] [channel] [epoch] [lr]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)/"
fi
}
train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path
train_data=$(get_real_path $2)
echo "train_data: "$train_data
if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a directory."
exit 1
fi
if [ ! -d $train_data ]
then
    echo "error: train_data=$train_data is not a directory."
exit 1
fi
nohup python ${train_code_path}train.py --is_distributed=0 \
--train_data=${train_data} \
--batch_size=$3 \
--sigma=$4 \
--channel=$5 \
--epoch=$6 \
--lr=$7 > log.txt 2>&1 &
echo 'Train task has been started successfully!'
echo 'Please check the log at log.txt'


@@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import glob
import numpy as np
import PIL.Image as Image
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler
class BRDNetDataset:
""" BRDNetDataset.
Args:
data_path: path of images, must end by '/'
sigma: noise level
channel: 3 for color, 1 for gray
"""
def __init__(self, data_path, sigma, channel):
images = []
        file_directory = glob.glob(data_path+'*.bmp') # note the data format
        for file in file_directory:
            images.append(file)
self.images = images
self.sigma = sigma
self.channel = channel
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
get_batch_x: image with noise
get_batch_y: image without noise
"""
if self.channel == 3:
get_batch_y = np.array(Image.open(self.images[index]), dtype='uint8')
else:
get_batch_y = np.expand_dims(np.array(Image.open(self.images[index]).convert('L'), dtype='uint8'), axis=2)
get_batch_y = get_batch_y.astype('float32')/255.0
noise = np.random.normal(0, self.sigma/255.0, get_batch_y.shape).astype('float32') # noise
get_batch_x = get_batch_y + noise # input image = clean image + noise
return get_batch_x, get_batch_y
def __len__(self):
return len(self.images)
def create_BRDNetDataset(data_path, sigma, channel, batch_size, device_num, rank, shuffle):
dataset = BRDNetDataset(data_path, sigma, channel)
dataset_len = len(dataset)
distributed_sampler = DistributedSampler(dataset_len, device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
data_set = ds.GeneratorDataset(dataset, column_names=["image", "label"], \
shuffle=shuffle, sampler=distributed_sampler)
data_set = data_set.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
data_set = data_set.map(input_columns=["label"], operations=hwc_to_chw, num_parallel_workers=8)
data_set = data_set.batch(batch_size, drop_remainder=True)
return data_set, dataset_len


@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BRDNetDataset distributed sampler."""
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
print("***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print("***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
indices = indices.tolist()
self.epoch += 1
# change to list type
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
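The pad-then-stride logic of `DistributedSampler` can be traced standalone: each rank gets every `num_replicas`-th index, and the index list is first padded so it divides evenly. A minimal sketch with 10 samples over 4 replicas, shuffle disabled (the sizes are illustrative):

```python
import math

dataset_size, num_replicas = 10, 4
num_samples = math.ceil(dataset_size / num_replicas)   # 3 per replica
total_size = num_samples * num_replicas                # 12 after padding

indices = list(range(dataset_size))
indices += indices[:total_size - len(indices)]         # pad: [0..9, 0, 1]

# each rank subsamples with stride num_replicas, as in __iter__ above
shards = [indices[rank:total_size:num_replicas] for rank in range(num_replicas)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

Note that the padded samples (0 and 1) are seen twice per epoch, which keeps every rank's shard the same length.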


@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, logger_name, rank):
"""Get Logger."""
logger = LOGGER(logger_name, rank)
logger.setup_logging_file(path, rank)
return logger


@@ -0,0 +1,150 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops as ops
#from batch_renorm import BatchRenormalization
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import functional as F
class BRDNet(nn.Cell):
    """
    BRDNet: two-branch image denoising network.

    Note: the BRN_* layers below are implemented with nn.BatchNorm2d in
    place of batch renormalization.

    Args:
        channel: 3 for color, 1 for gray
    """
def __init__(self, channel):
super(BRDNet, self).__init__()
self.Conv2d_1 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.BRN_1 = nn.BatchNorm2d(64, eps=1e-3)
self.layer1 = self.make_layer1(15)
self.Conv2d_2 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.Conv2d_3 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.BRN_2 = nn.BatchNorm2d(64, eps=1e-3)
self.layer2 = self.make_layer2(7)
self.Conv2d_4 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.BRN_3 = nn.BatchNorm2d(64, eps=1e-3)
self.layer3 = self.make_layer2(6)
self.Conv2d_5 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.BRN_4 = nn.BatchNorm2d(64, eps=1e-3)
self.Conv2d_6 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.Conv2d_7 = nn.Conv2d(channel*2, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
self.relu = nn.ReLU()
self.sub = ops.Sub()
        self.concat = ops.Concat(axis=1)  # concatenate along the channel axis (NCHW)
def make_layer1(self, nums):
layers = []
assert nums > 0
for _ in range(nums):
layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True))
layers.append(nn.BatchNorm2d(64, eps=1e-3))
layers.append(nn.ReLU())
return nn.SequentialCell(layers)
def make_layer2(self, nums):
layers = []
assert nums > 0
for _ in range(nums):
layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), \
stride=(1, 1), dilation=(2, 2), pad_mode='same', has_bias=True))
layers.append(nn.ReLU())
return nn.SequentialCell(layers)
def construct(self, inpt):
        # inpt: input tensor in NCHW layout
x = self.Conv2d_1(inpt)
x = self.BRN_1(x)
x = self.relu(x)
# 15 layers, Conv+BN+relu
x = self.layer1(x)
# last layer, Conv
        x = self.Conv2d_2(x)  # output channels: 1 for gray, 3 for color
x = self.sub(inpt, x) # input - noise
y = self.Conv2d_3(inpt)
y = self.BRN_2(y)
y = self.relu(y)
        # first dilated Conv+ReLU block of the lower branch (layers 2-8)
y = self.layer2(y)
y = self.Conv2d_4(y)
y = self.BRN_3(y)
y = self.relu(y)
        # second dilated Conv+ReLU block of the lower branch (layers 10-15)
y = self.layer3(y)
y = self.Conv2d_5(y)
y = self.BRN_4(y)
y = self.relu(y)
        y = self.Conv2d_6(y)  # output channels: 1 for gray, 3 for color
y = self.sub(inpt, y) # input - noise
o = self.concat((x, y))
        z = self.Conv2d_7(o)  # fuse the two branches; output channels: 1 for gray, 3 for color
z = self.sub(inpt, z)
return z
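The forward pass above applies the residual ("noise prediction") pattern three times: each branch subtracts its predicted noise from the input, and the fused output is subtracted once more. A minimal NumPy sketch of this dataflow, with hypothetical `upper`, `lower`, and `fuse` stand-ins for the convolutional stacks:

```python
import numpy as np

def brdnet_dataflow(noisy, upper, lower, fuse):
    # Each branch predicts noise and subtracts it from the input (residual learning).
    x = noisy - upper(noisy)            # upper-branch clean estimate
    y = noisy - lower(noisy)            # lower-branch clean estimate
    o = np.concatenate([x, y], axis=1)  # concat along the channel axis (NCHW)
    return noisy - fuse(o)              # final residual subtraction after fusion

# Toy check: if every stage predicts zero noise, the input passes through unchanged.
noisy = np.random.rand(2, 3, 8, 8)
zero = lambda t: np.zeros_like(t)
fuse_zero = lambda o: np.zeros(noisy.shape)  # hypothetical 2C -> C fusion
out = brdnet_dataflow(noisy, zero, zero, fuse_zero)
assert out.shape == noisy.shape
```

This only illustrates the wiring (subtract, concat, subtract), not the learned convolutions themselves.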
class BRDWithLossCell(nn.Cell):
def __init__(self, network):
super(BRDWithLossCell, self).__init__()
self.network = network
self.loss = nn.MSELoss(reduction='sum') #we use 'sum' instead of 'mean' to avoid the loss becoming too small
def construct(self, images, targets):
output = self.network(images)
return self.loss(output, targets)
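The `'sum'` reduction noted above simply rescales the per-batch MSE by the number of elements, so the loss (and the gradient magnitude under a fixed learning rate) does not shrink toward zero on large patch batches. A quick NumPy illustration of the relationship:

```python
import numpy as np

pred = np.array([1.0, 2.0, 3.0, 4.0])
target = np.zeros(4)
sq_err = (pred - target) ** 2   # element-wise squared error
mse_mean = sq_err.mean()        # 'mean' reduction
mse_sum = sq_err.sum()          # 'sum' reduction, as used in BRDWithLossCell
assert mse_sum == mse_mean * pred.size
```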
class TrainingWrapper(nn.Cell):
"""Training wrapper."""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
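The `sens` value fills the initial gradient fed into the backward pass, which is equivalent to differentiating `sens * loss`: every weight gradient is scaled by `sens` (the basis of static loss scaling). A finite-difference sketch on a toy quadratic loss, independent of MindSpore:

```python
def loss(w):
    # toy quadratic loss with minimum at w = 3
    return (w - 3.0) ** 2

def grad_fd(f, w, eps=1e-6):
    # central finite-difference approximation of df/dw
    return (f(w + eps) - f(w - eps)) / (2 * eps)

sens = 1024.0
g = grad_fd(loss, 5.0)                           # d/dw (w-3)^2 at w=5 -> 4
g_scaled = grad_fd(lambda w: sens * loss(w), 5.0)
assert abs(g - 4.0) < 1e-4
assert abs(g_scaled - sens * 4.0) < 1e-2         # gradients scale linearly with sens
```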


@ -0,0 +1,182 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import datetime
import argparse
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore.train.callback import TimeMonitor, LossMonitor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_rank, get_group_size
from src.logger import get_logger
from src.dataset import create_BRDNetDataset
from src.models import BRDNet, BRDWithLossCell, TrainingWrapper
## Params
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int, help='batch size')
parser.add_argument('--train_data', default='../dataset/waterloo5050step40colorimage/'
, type=str, help='path of train data')
parser.add_argument('--sigma', default=75, type=int, help='noise level')
parser.add_argument('--channel', default=3, type=int
, help='image channel, 3 for color, 1 for gray')
parser.add_argument('--epoch', default=50, type=int, help='number of training epochs')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--save_every', default=1, type=int, help='save the model every x epochs')
parser.add_argument('--pretrain', default=None, type=str, help='path of pre-trained model')
parser.add_argument('--use_modelarts', type=int, default=0
                    , help='1 for True, 0 for False; when True, load the dataset from obs with moxing')
parser.add_argument('--train_url', type=str, default='train_url/'
                    , help='needed by ModelArts, but not used directly because the name is ambiguous')
parser.add_argument('--data_url', type=str, default='data_url/'
                    , help='needed by ModelArts, but not used directly because the name is ambiguous')
parser.add_argument('--output_path', type=str, default='./output/'
                    , help='output path; when use_modelarts is True, it becomes cache/output/')
parser.add_argument('--outer_path', type=str, default='s3://output/'
                    , help='obs path to store outputs, e.g. ckpt files')
parser.add_argument('--device_target', type=str, default='Ascend'
, help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
parser.add_argument('--ckpt_save_max', type=int, default=5
, help='Maximum number of checkpoint files can be saved. Default: 5.')
set_seed(1)
args = parser.parse_args()
save_dir = os.path.join(args.output_path, 'sigma_' + str(args.sigma) \
+ '_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if not args.use_modelarts and not os.path.exists(save_dir):
os.makedirs(save_dir)
def get_lr(steps_per_epoch, max_epoch, init_lr):
    """Step-decay schedule: hold the lr for 30-epoch blocks, dividing by 10 between blocks."""
    lr_each_step = []
    while max_epoch > 0:
        tem = min(30, max_epoch)
        for _ in range(steps_per_epoch*tem):
            lr_each_step.append(init_lr)
        max_epoch -= tem
        init_lr /= 10
    return lr_each_step
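The schedule built above can be checked numerically; this standalone copy of `get_lr` with toy parameters confirms the 30-epoch plateaus:

```python
def get_lr(steps_per_epoch, max_epoch, init_lr):
    # Same logic as in train.py: constant lr in 30-epoch blocks, /10 between blocks.
    lr_each_step = []
    while max_epoch > 0:
        tem = min(30, max_epoch)
        for _ in range(steps_per_epoch * tem):
            lr_each_step.append(init_lr)
        max_epoch -= tem
        init_lr /= 10
    return lr_each_step

lrs = get_lr(steps_per_epoch=2, max_epoch=50, init_lr=1e-3)
assert len(lrs) == 2 * 50       # one entry per training step
assert lrs[0] == 1e-3           # epochs 1-30 run at the initial lr
assert lrs[2 * 30] == 1e-4      # epochs 31-50 run at lr / 10
```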
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False)
def train():
if args.is_distributed:
assert args.device_target == "Ascend"
init()
context.set_context(device_id=device_id)
args.rank = get_rank()
args.group_size = get_group_size()
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
else:
if args.device_target == "Ascend":
context.set_context(device_id=device_id)
    # choose whether only the master rank or every rank saves ckpt (compatible with model parallel)
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
args.logger = get_logger(save_dir, "BRDNet", args.rank)
args.logger.save_args(args)
if args.use_modelarts:
import moxing as mox
args.logger.info("copying train data from obs to cache....")
mox.file.copy_parallel(args.train_data, 'cache/dataset') # the args.train_data must end by '/'
args.logger.info("copying traindata finished....")
args.train_data = 'cache/dataset/' # the args.train_data must end by '/'
dataset, batch_num = create_BRDNetDataset(args.train_data, args.sigma, \
args.channel, args.batch_size, args.group_size, args.rank, shuffle=True)
args.steps_per_epoch = int(batch_num / args.batch_size / args.group_size)
model = BRDNet(args.channel)
if args.pretrain:
if args.use_modelarts:
import moxing as mox
args.logger.info("copying pretrain model from obs to cache....")
mox.file.copy_parallel(args.pretrain, 'cache/pretrain')
args.logger.info("copying pretrain model finished....")
args.pretrain = 'cache/pretrain/'+args.pretrain.split('/')[-1]
args.logger.info('loading pre-trained model {} into network'.format(args.pretrain))
load_param_into_net(model, load_checkpoint(args.pretrain))
args.logger.info('loaded pre-trained model {} into network'.format(args.pretrain))
model = BRDWithLossCell(model)
model.set_train()
lr_list = get_lr(args.steps_per_epoch, args.epoch, args.lr)
optimizer = nn.Adam(params=model.trainable_params(), learning_rate=Tensor(lr_list, mindspore.float32))
model = TrainingWrapper(model, optimizer)
model = Model(model)
# define callbacks
if args.rank == 0:
time_cb = TimeMonitor(data_size=batch_num)
loss_cb = LossMonitor(per_print_times=10)
callbacks = [time_cb, loss_cb]
else:
callbacks = None
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='channel_'+str(args.channel)+'_sigma_'+str(args.sigma)+'_rank_'+str(args.rank))
callbacks.append(ckpt_cb)
model.train(args.epoch, dataset, callbacks=callbacks, dataset_sink_mode=True)
args.logger.info("training finished....")
if args.use_modelarts:
args.logger.info("copying files from cache to obs....")
mox.file.copy_parallel(save_dir, args.outer_path)
args.logger.info("copying finished....")
if __name__ == '__main__':
train()
    args.logger.info('All tasks finished!')