diff --git a/model_zoo/official/cv/brdnet/README_CN.md b/model_zoo/official/cv/brdnet/README_CN.md new file mode 100644 index 00000000000..8d5124dcdde --- /dev/null +++ b/model_zoo/official/cv/brdnet/README_CN.md @@ -0,0 +1,532 @@ +# BRDNet + + + +- [BRDNet](#BRDNet) +- [BRDNet介绍](#BRDNet介绍) +- [模型结构](#模型结构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [快速入门](#快速入门) +- [脚本说明](#脚本说明) + - [脚本及样例代码](#脚本及样例代码) + - [脚本参数](#脚本参数) + - [训练过程](#训练过程) + - [训练](#训练) + - [评估过程](#评估过程) + - [评估](#评估) + - [310推理](#310推理) +- [模型描述](#模型描述) + - [性能](#性能) + - [评估性能](#评估性能) +- [随机情况说明](#随机情况说明) +- [ModelZoo主页](#modelzoo主页) + + + +## BRDNet介绍 + +BRDNet 于 2020 年发表在人工智能期刊 Neural Networks,并被该期刊评为 2019/2020 年最高下载量之一的论文。该文提出了在硬件资源受限条件处理数据分步不均匀问题的方案,同时首次提出利用双路网络提取互补信息的思路来进行图像去噪,能有效去除合成噪声和真实噪声。 + +[论文](https://www.sciencedirect.com/science/article/pii/S0893608019302394):Ct A , Yong X , Wz C . Image denoising using deep CNN with batch renormalization[J]. Neural Networks, 2020, 121:461-473. + +## 模型结构 + +BRDNet 包含上下两个分支。上分支仅仅包含残差学习与 BRN;下分支包含 BRN、残差学习以及扩张卷积。 + +上分支网络包含两种不同类型的层:Conv+BRN+ReLU 与 Conv。它的深度为 17,前 16 个层为 Conv+BRN+ReLU, 最后一层为 Conv。特征图通道数为 64,卷积核尺寸为 3. +下分支网络同样包含 17 层,但其 1、9、16 层为 Conv+BRN+ReLU,2-8、10-15 为扩张卷积,最后一层为 Conv。卷积核尺寸为 3,通道数为 64. 
+ +两个分支结果经由 `Concat` 组合并经 Conv 得到噪声,最后采用原始输入减去该噪声即可得到清晰无噪图像。整个 BRDNet 合计包含 18 层,相对比较浅,不会导致梯度消失与梯度爆炸问题。 + +## 数据集 + +训练数据集:[color noisy]() + +去除高斯噪声时,数据集选用源于 Waterloo Exploration Database 的 3,859 张图片,并预处理为 50x50 的小图片,共计 1,166,393 张。 + +测试数据集:CBSD68, Kodak24, and McMaster + +## 环境要求 + +- 硬件(Ascend/ModelArts) + - 准备Ascend或ModelArts处理器搭建硬件环境。 +- 框架 + - [MindSpore](https://www.mindspore.cn/install) +- 如需查看详情,请参见如下资源: + - [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.html) + +## 快速入门 + +通过官方网站安装 MindSpore 后,您可以按照如下步骤进行训练和评估: + +```shell +#通过 python 命令行运行单卡训练脚本。 请注意 train_data 需要以"/"结尾 +python train.py \ +--train_data=xxx/dataset/waterloo5050step40colorimage/ \ +--sigma=15 \ +--channel=3 \ +--batch_size=32 \ +--lr=0.001 \ +--use_modelarts=0 \ +--output_path=./output/ \ +--is_distributed=0 \ +--epoch=50 > log.txt 2>&1 & + +#通过 sh 命令启动单卡训练。(对 train_data 等参数的路径格式无要求,内部会自动转为绝对路径以及以"/"结尾) +sh ./scripts/run_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] + +#Ascend多卡训练。 +sh run_distribute_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path] + +# 通过 python 命令行运行推理脚本。请注意 test_dir 需要以"/"结尾; +# pretrain_path 指 ckpt 所在目录,为了兼容 modelarts,将其拆分为了 “路径” 与 “文件名” +python eval.py \ +--test_dir=xxx/Test/Kodak24/ \ +--sigma=15 \ +--channel=3 \ +--pretrain_path=xxx/ \ +--ckpt_name=channel_3_sigma_15_rank_0-50_227800.ckpt \ +--use_modelarts=0 \ +--output_path=./output/ \ +--is_distributed=0 > log.txt 2>&1 & + +#通过 sh 命令启动推理。(对 test_dir 等参数的路径格式无要求,内部会自动转为绝对路径以及以"/"结尾) +sh run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name] +``` + +Ascend训练:生成[RANK_TABLE_FILE](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) + +## 脚本说明 + +### 脚本及样例代码 + +```shell +├── model_zoo + ├── README.md // 所有模型的说明文件 + ├── brdnet + ├── README.md // 
brdnet 的说明文件 + ├── ascend310_infer // 310 推理代码目录(C++) + │ ├──inc + │ │ ├──utils.h // 工具包头文件 + │ ├──src + │ │ ├──main.cc // 310推理代码 + │ │ ├──utils.cc // 工具包 + │ ├──build.sh // 代码编译脚本 + │ ├──CMakeLists.txt // 代码编译设置 + ├── scripts + │ ├──run_distribute_train.sh // Ascend 8卡训练脚本 + │ ├──run_eval.sh // 推理启动脚本 + │ ├──run_train.sh // 训练启动脚本 + │ ├──run_infer_310.sh // 启动310推理的脚本 + ├── src + │ ├──dataset.py // 数据集处理 + │ ├──distributed_sampler.py // 8卡并行时的数据集切分操作 + │ ├──logger.py // 日志打印文件 + │ ├──models.py // 模型结构 + ├── export.py // 将权重文件导出为 MINDIR 等格式的脚本 + ├── train.py // 训练脚本 + ├── eval.py // 推理脚本 + ├── cal_psnr.py // 310推理时计算最终PSNR值的脚本 + ├── preprocess.py // 310推理时为测试图片添加噪声的脚本 +``` + +### 脚本参数 + +```shell +train.py 中的主要参数如下: +--batch_size: 批次大小 +--train_data: 训练数据集路径(必须以"/"结尾)。 +--sigma: 高斯噪声强度 +--channel: 训练类型(3:彩色图;1:灰度图) +--epoch: 训练次数 +--lr: 初始学习率 +--save_every: 权重保存频率(每 N 个 epoch 保存一次) +--pretrain: 预训练文件(接着该文件继续训练) +--use_modelarts: 是否使用 modelarts(1 for True, 0 for False; 设置为 1 时将使用 moxing 从 obs 拷贝数据) +--train_url: ( modelsarts 需要的参数,但因该名称存在歧义而在代码中未使用) +--data_url: ( modelsarts 需要的参数,但因该名称存在歧义而在代码中未使用) +--output_path: 日志等文件输出目录 +--outer_path: 输出到 obs 外部的目录(仅在 modelarts 上运行时有效) +--device_target: 运行设备(默认 "Ascend") +--is_distributed: 是否多卡运行 +--rank: Local rank of distributed. Default: 0 +--group_size: World size of device. 
Default: 1 +--is_save_on_master: 是否仅保存 0 卡上的运行结果 +--ckpt_save_max: ckpt 保存的最多文件数 + +eval.py 中的主要参数如下: +--test_dir: 测试数据集路径(必须以"/"结尾)。 +--sigma: 高斯噪声强度 +--channel: 推理类型(3:彩色图;1:灰度图) +--pretrain_path: 权重文件路径 +--ckpt_name: 权重文件名称 +--use_modelarts: 是否使用 modelarts(1 for True, 0 for False; 设置为 1 时将使用 moxing 从 obs 拷贝数据) +--train_url: ( modelsarts 需要的参数,但因该名称存在歧义而在代码中未使用) +--data_url: ( modelsarts 需要的参数,但因该名称存在歧义而在代码中未使用) +--output_path: 日志等文件输出目录 +--outer_path: 输出到 obs 外部的目录(仅在 modelarts 上运行时有效) +--device_target: 运行设备(默认 "Ascend") + +export.py 中的主要参数如下: +--batch_size: 批次大小 +--channel: 训练类型(3:彩色图;1:灰度图) +--image_height: 图片高度 +--image_width: 图片宽度 +--ckpt_file: 权重文件路径 +--file_name: 权重文件名称 +--file_format: 待转文件格式,choices=["AIR", "ONNX", "MINDIR"] +--device_target: 运行设备(默认 "Ascend") +--device_id: 运行设备id + +preprocess.py 中的主要参数如下: +--out_dir: 保存噪声图片的路径 +--image_path: 测试图片路径(必须以"/"结尾) +--channel: 图片通道(3:彩色图;1:灰度图) +--sigma: 噪声强度 + +cal_psnr.py 中的主要参数如下: +--image_path: 测试图片路径(必须以"/"结尾) +--output_path: 模型推理完成后的结果保存路径,默认为 xx/scripts/result_Files/ +--image_width: 图片宽度 +--image_height: 图片高度 +--channel:图片通道(3:彩色图;1:灰度图) +``` + +### 训练过程 + +#### 训练 + +- Ascend处理器环境运行 + + ```shell + #通过 python 命令行运行单卡训练脚本。 请注意 train_data 需要以"/"结尾 + python train.py \ + --train_data=xxx/dataset/waterloo5050step40colorimage/ \ + --sigma=15 \ + --channel=3 \ + --batch_size=32 \ + --lr=0.001 \ + --use_modelarts=0 \ + --output_path=./output/ \ + --is_distributed=0 \ + --epoch=50 > log.txt 2>&1 & + + #通过 sh 命令启动单卡训练。(对 train_data 等参数的路径格式无要求,内部会自动转为绝对路径以及以"/"结尾) + sh ./scripts/run_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] + + #上述命令均会使脚本在后台运行,日志将输出到 log.txt,可通过查看该文件了解训练详情 + + #Ascend多卡训练(2、4、8卡配置请自行修改run_distribute_train.sh,默认8卡) + sh run_distribute_train.sh [train_code_path] [train_data] [batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path] + ``` + + 注意:第一次运行时可能会较长时间停留在如下界面,这是因为当一个 epoch 运行完成后才会打印日志,请耐心等待。 + + 单卡运行时第一个 epoch 预计耗时 20 ~ 30 分钟。 + + ```python + 
2021-05-16 20:12:17,888:INFO:Args: + 2021-05-16 20:12:17,888:INFO:--> batch_size: 32 + 2021-05-16 20:12:17,888:INFO:--> train_data: ../dataset/waterloo5050step40colorimage/ + 2021-05-16 20:12:17,889:INFO:--> sigma: 15 + 2021-05-16 20:12:17,889:INFO:--> channel: 3 + 2021-05-16 20:12:17,889:INFO:--> epoch: 50 + 2021-05-16 20:12:17,889:INFO:--> lr: 0.001 + 2021-05-16 20:12:17,889:INFO:--> save_every: 1 + 2021-05-16 20:12:17,889:INFO:--> pretrain: None + 2021-05-16 20:12:17,889:INFO:--> use_modelarts: False + 2021-05-16 20:12:17,889:INFO:--> train_url: train_url/ + 2021-05-16 20:12:17,889:INFO:--> data_url: data_url/ + 2021-05-16 20:12:17,889:INFO:--> output_path: ./output/ + 2021-05-16 20:12:17,889:INFO:--> outer_path: s3://output/ + 2021-05-16 20:12:17,889:INFO:--> device_target: Ascend + 2021-05-16 20:12:17,890:INFO:--> is_distributed: 0 + 2021-05-16 20:12:17,890:INFO:--> rank: 0 + 2021-05-16 20:12:17,890:INFO:--> group_size: 1 + 2021-05-16 20:12:17,890:INFO:--> is_save_on_master: 1 + 2021-05-16 20:12:17,890:INFO:--> ckpt_save_max: 5 + 2021-05-16 20:12:17,890:INFO:--> rank_save_ckpt_flag: 1 + 2021-05-16 20:12:17,890:INFO:--> logger: + 2021-05-16 20:12:17,890:INFO: + ``` + + 训练完成后,您可以在 --output_path 参数指定的目录下找到保存的权重文件,训练过程中的部分 loss 收敛情况如下: + + ```python + # grep "epoch time:" log.txt + epoch time: 1197471.061 ms, per step time: 32.853 ms + epoch time: 1136826.065 ms, per step time: 31.189 ms + epoch time: 1136840.334 ms, per step time: 31.190 ms + epoch time: 1136837.709 ms, per step time: 31.190 ms + epoch time: 1137081.757 ms, per step time: 31.197 ms + epoch time: 1136830.581 ms, per step time: 31.190 ms + epoch time: 1136845.253 ms, per step time: 31.190 ms + epoch time: 1136881.960 ms, per step time: 31.191 ms + epoch time: 1136850.673 ms, per step time: 31.190 ms + epoch: 10 step: 36449, loss is 103.104095 + epoch time: 1137098.407 ms, per step time: 31.197 ms + epoch time: 1136794.613 ms, per step time: 31.189 ms + epoch time: 1136742.922 ms, per step time: 
31.187 ms + epoch time: 1136842.009 ms, per step time: 31.190 ms + epoch time: 1136792.705 ms, per step time: 31.189 ms + epoch time: 1137056.362 ms, per step time: 31.196 ms + epoch time: 1136863.373 ms, per step time: 31.191 ms + epoch time: 1136842.938 ms, per step time: 31.190 ms + epoch time: 1136839.011 ms, per step time: 31.190 ms + epoch time: 1136879.794 ms, per step time: 31.191 ms + epoch: 20 step: 36449, loss is 61.104546 + epoch time: 1137035.395 ms, per step time: 31.195 ms + epoch time: 1136830.626 ms, per step time: 31.190 ms + epoch time: 1136862.117 ms, per step time: 31.190 ms + epoch time: 1136812.265 ms, per step time: 31.189 ms + epoch time: 1136821.096 ms, per step time: 31.189 ms + epoch time: 1137050.310 ms, per step time: 31.196 ms + epoch time: 1136815.292 ms, per step time: 31.189 ms + epoch time: 1136817.757 ms, per step time: 31.189 ms + epoch time: 1136876.477 ms, per step time: 31.191 ms + epoch time: 1136798.538 ms, per step time: 31.189 ms + epoch: 30 step: 36449, loss is 116.179596 + epoch time: 1136972.930 ms, per step time: 31.194 ms + epoch time: 1136825.174 ms, per step time: 31.189 ms + epoch time: 1136798.900 ms, per step time: 31.189 ms + epoch time: 1136828.101 ms, per step time: 31.190 ms + epoch time: 1136862.983 ms, per step time: 31.191 ms + epoch time: 1136989.445 ms, per step time: 31.194 ms + epoch time: 1136688.820 ms, per step time: 31.186 ms + epoch time: 1136858.111 ms, per step time: 31.190 ms + epoch time: 1136822.853 ms, per step time: 31.189 ms + epoch time: 1136782.455 ms, per step time: 31.188 ms + epoch: 40 step: 36449, loss is 70.95368 + epoch time: 1137042.689 ms, per step time: 31.195 ms + epoch time: 1136797.706 ms, per step time: 31.189 ms + epoch time: 1136817.007 ms, per step time: 31.189 ms + epoch time: 1136861.577 ms, per step time: 31.190 ms + epoch time: 1136698.149 ms, per step time: 31.186 ms + epoch time: 1137052.034 ms, per step time: 31.196 ms + epoch time: 1136809.339 ms, per step time: 
31.189 ms + epoch time: 1136851.343 ms, per step time: 31.190 ms + epoch time: 1136761.354 ms, per step time: 31.188 ms + epoch time: 1136837.762 ms, per step time: 31.190 ms + epoch: 50 step: 36449, loss is 87.13184 + epoch time: 1137022.554 ms, per step time: 31.195 ms + 2021-05-19 14:24:52,695:INFO:training finished.... + ... + ``` + + 8 卡并行时的 loss 收敛情况: + + ```python + epoch time: 217708.130 ms, per step time: 47.785 ms + epoch time: 144899.598 ms, per step time: 31.804 ms + epoch time: 144736.054 ms, per step time: 31.768 ms + epoch time: 144737.085 ms, per step time: 31.768 ms + epoch time: 144738.102 ms, per step time: 31.769 ms + epoch: 5 step: 4556, loss is 106.67432 + epoch time: 144905.830 ms, per step time: 31.805 ms + epoch time: 144736.539 ms, per step time: 31.768 ms + epoch time: 144734.210 ms, per step time: 31.768 ms + epoch time: 144734.415 ms, per step time: 31.768 ms + epoch time: 144736.405 ms, per step time: 31.768 ms + epoch: 10 step: 4556, loss is 94.092865 + epoch time: 144921.081 ms, per step time: 31.809 ms + epoch time: 144735.718 ms, per step time: 31.768 ms + epoch time: 144737.036 ms, per step time: 31.768 ms + epoch time: 144737.733 ms, per step time: 31.769 ms + epoch time: 144738.251 ms, per step time: 31.769 ms + epoch: 15 step: 4556, loss is 99.18075 + epoch time: 144921.945 ms, per step time: 31.809 ms + epoch time: 144734.948 ms, per step time: 31.768 ms + epoch time: 144735.662 ms, per step time: 31.768 ms + epoch time: 144733.871 ms, per step time: 31.768 ms + epoch time: 144734.722 ms, per step time: 31.768 ms + epoch: 20 step: 4556, loss is 92.54497 + epoch time: 144907.430 ms, per step time: 31.806 ms + epoch time: 144735.713 ms, per step time: 31.768 ms + epoch time: 144733.781 ms, per step time: 31.768 ms + epoch time: 144736.005 ms, per step time: 31.768 ms + epoch time: 144734.331 ms, per step time: 31.768 ms + epoch: 25 step: 4556, loss is 90.98991 + epoch time: 144911.420 ms, per step time: 31.807 ms + epoch time: 
144734.535 ms, per step time: 31.768 ms + epoch time: 144734.851 ms, per step time: 31.768 ms + epoch time: 144736.346 ms, per step time: 31.768 ms + epoch time: 144734.939 ms, per step time: 31.768 ms + epoch: 30 step: 4556, loss is 114.33954 + epoch time: 144915.434 ms, per step time: 31.808 ms + epoch time: 144737.336 ms, per step time: 31.769 ms + epoch time: 144733.943 ms, per step time: 31.768 ms + epoch time: 144734.587 ms, per step time: 31.768 ms + epoch time: 144735.043 ms, per step time: 31.768 ms + epoch: 35 step: 4556, loss is 97.21166 + epoch time: 144912.719 ms, per step time: 31.807 ms + epoch time: 144734.795 ms, per step time: 31.768 ms + epoch time: 144733.824 ms, per step time: 31.768 ms + epoch time: 144735.946 ms, per step time: 31.768 ms + epoch time: 144734.930 ms, per step time: 31.768 ms + epoch: 40 step: 4556, loss is 82.41978 + epoch time: 144901.017 ms, per step time: 31.804 ms + epoch time: 144735.060 ms, per step time: 31.768 ms + epoch time: 144733.657 ms, per step time: 31.768 ms + epoch time: 144732.592 ms, per step time: 31.767 ms + epoch time: 144731.292 ms, per step time: 31.767 ms + epoch: 45 step: 4556, loss is 77.92129 + epoch time: 144909.250 ms, per step time: 31.806 ms + epoch time: 144732.944 ms, per step time: 31.768 ms + epoch time: 144733.161 ms, per step time: 31.768 ms + epoch time: 144732.912 ms, per step time: 31.768 ms + epoch time: 144733.709 ms, per step time: 31.768 ms + epoch: 50 step: 4556, loss is 85.499596 + 2021-05-19 02:44:44,219:INFO:training finished.... 
+ ``` + +### 评估过程 + +#### 评估 + +在运行以下命令之前,请检查用于推理评估的权重文件路径是否正确。 + +- Ascend处理器环境运行 + + ```python + #通过 python 命令启动评估,请注意 test_dir 需要以"/" 结尾 + python eval.py \ + --test_dir=./Test/CBSD68/ \ + --sigma=15 \ + --channel=3 \ + --pretrain_path=./ \ + --ckpt_name=channel_3_sigma_15_rank_0-50_227800.ckpt \ + --use_modelarts=0 \ + --output_path=./output/ \ + --is_distributed=0 > log.txt 2>&1 & + + #通过 sh 命令启动评估 (对 test_dir 等参数的路径格式无要求,内部会自动转为绝对路径以及以"/"结尾) + sh run_eval.sh [train_code_path] [test_dir] [sigma] [channel] [pretrain_path] [ckpt_name] + ``` + + ```python + 2021-05-17 13:40:45,909:INFO:Start to test on ./Test/CBSD68/ + 2021-05-17 13:40:46,447:INFO:load test weights from channel_3_sigma_15_rank_0-50_227800.ckpt + 2021-05-17 13:41:52,164:INFO:Before denoise: Average PSNR_b = 24.62, SSIM_b = 0.56;After denoise: Average PSNR = 34.05, SSIM = 0.94 + 2021-05-17 13:41:52,207:INFO:testing finished.... + ``` + +评估完成后,您可以在 --output_path 参数指定的目录下找到 加高斯噪声后的图片和经过模型去除高斯噪声后的图片,图片命名方式代表了处理结果。例如 00001_sigma15_psnr24.62.bmp 是加噪声后的图片(加噪声后 psnr=24.62),00001_psnr31.18.bmp 是去噪声后的图片(去噪后 psnr=31.18)。 + +另外,该文件夹下的 metrics.csv 文件详细记录了对每张测试图片的处理结果,如下所示,psnr_b 是去噪前的 psnr 值,psnr 是去噪后的psnr 值;ssim 指标同理。 + +| | name | psnr_b | psnr | ssim_b | ssim | +| ---- | ------- | ----------- | ----------- | ----------- | ----------- | +| 0 | 1 | 24.61875916 | 31.17827606 | 0.716650724 | 0.910416007 | +| 1 | 2 | 24.61875916 | 35.12858963 | 0.457143694 | 0.995960176 | +| 2 | 3 | 24.61875916 | 34.90437698 | 0.465185702 | 0.935821533 | +| 3 | 4 | 24.61875916 | 35.59785461 | 0.49323535 | 0.941600204 | +| 4 | 5 | 24.61875916 | 32.9185257 | 0.605194688 | 0.958840668 | +| 5 | 6 | 24.61875916 | 37.29947662 | 0.368243992 | 0.962466478 | +| 6 | 7 | 24.61875916 | 33.59238052 | 0.622622728 | 0.930195987 | +| 7 | 8 | 24.61875916 | 31.76290894 | 0.680295587 | 0.918859363 | +| 8 | 9 | 24.61875916 | 34.13358688 | 0.55876708 | 0.939204693 | +| 9 | 10 | 24.61875916 | 34.49848557 | 0.503289104 | 0.928179622 | +| 10 | 11 | 
24.61875916 | 34.38597107 | 0.656857133 | 0.961226702 | +| 11 | 12 | 24.61875916 | 32.75747299 | 0.627940595 | 0.910765707 | +| 12 | 13 | 24.61875916 | 34.52487564 | 0.54259634 | 0.936489582 | +| 13 | 14 | 24.61875916 | 35.40441132 | 0.44824928 | 0.93462956 | +| 14 | 15 | 24.61875916 | 32.72385788 | 0.61768961 | 0.91652298 | +| 15 | 16 | 24.61875916 | 33.59120178 | 0.703662276 | 0.948698342 | +| 16 | 17 | 24.61875916 | 36.85597229 | 0.365240872 | 0.940135658 | +| 17 | 18 | 24.61875916 | 37.23021317 | 0.366332233 | 0.953653395 | +| 18 | 19 | 24.61875916 | 33.49061584 | 0.546713233 | 0.928890586 | +| 19 | 20 | 24.61875916 | 33.98015213 | 0.463814735 | 0.938398063 | +| 20 | 21 | 24.61875916 | 32.15977859 | 0.714740098 | 0.945747674 | +| 21 | 22 | 24.61875916 | 32.39984512 | 0.716880679 | 0.930429876 | +| 22 | 23 | 24.61875916 | 34.22258759 | 0.569748521 | 0.945626318 | +| 23 | 24 | 24.61875916 | 33.974823 | 0.603115499 | 0.941333234 | +| 24 | 25 | 24.61875916 | 34.87198639 | 0.486003697 | 0.966141582 | +| 25 | 26 | 24.61875916 | 33.2747879 | 0.593207896 | 0.917522907 | +| 26 | 27 | 24.61875916 | 34.67901611 | 0.504613101 | 0.921615481 | +| 27 | 28 | 24.61875916 | 37.70562363 | 0.331322074 | 0.977765024 | +| 28 | 29 | 24.61875916 | 31.08887672 | 0.759773433 | 0.958483219 | +| 29 | 30 | 24.61875916 | 34.48878479 | 0.502000451 | 0.915705442 | +| 30 | 31 | 24.61875916 | 30.5480938 | 0.836367846 | 0.949165165 | +| 31 | 32 | 24.61875916 | 32.08041382 | 0.745214283 | 0.941719413 | +| 32 | 33 | 24.61875916 | 33.65553284 | 0.556162357 | 0.963605523 | +| 33 | 34 | 24.61875916 | 36.87154388 | 0.384932011 | 0.93150568 | +| 34 | 35 | 24.61875916 | 33.03263474 | 0.586027861 | 0.924151421 | +| 35 | 36 | 24.61875916 | 31.80633736 | 0.572878599 | 0.84426564 | +| 36 | 37 | 24.61875916 | 33.26797485 | 0.526310742 | 0.938789487 | +| 37 | 38 | 24.61875916 | 33.71062469 | 0.554955184 | 0.914420724 | +| 38 | 39 | 24.61875916 | 37.3455925 | 0.461908668 | 0.956513464 | +| 39 | 40 | 
24.61875916 | 33.92232895 | 0.554454744 | 0.89727515 | +| 40 | 41 | 24.61875916 | 33.05244827 | 0.590977669 | 0.931611121 | +| 41 | 42 | 24.61875916 | 34.60203552 | 0.492371827 | 0.927084684 | +| 42 | 43 | 24.61875916 | 35.20042419 | 0.535991669 | 0.949365258 | +| 43 | 44 | 24.61875916 | 33.47367096 | 0.614959836 | 0.954348624 | +| 44 | 45 | 24.61875916 | 37.65309143 | 0.363631308 | 0.944297135 | +| 45 | 46 | 24.61875916 | 31.95152092 | 0.709732175 | 0.924522877 | +| 46 | 47 | 24.61875916 | 31.9910202 | 0.70427531 | 0.932488263 | +| 47 | 48 | 24.61875916 | 34.96608353 | 0.585813344 | 0.969006479 | +| 48 | 49 | 24.61875916 | 35.39241409 | 0.388898522 | 0.923762918 | +| 49 | 50 | 24.61875916 | 32.11050415 | 0.653521299 | 0.938310325 | +| 50 | 51 | 24.61875916 | 33.54981995 | 0.594990134 | 0.927819192 | +| 51 | 52 | 24.61875916 | 35.79096603 | 0.371685684 | 0.922166049 | +| 52 | 53 | 24.61875916 | 35.10015869 | 0.410564244 | 0.895557165 | +| 53 | 54 | 24.61875916 | 34.12319565 | 0.591762364 | 0.925524533 | +| 54 | 55 | 24.61875916 | 32.79537964 | 0.653338313 | 0.92444253 | +| 55 | 56 | 24.61875916 | 29.90909004 | 0.826190114 | 0.943322361 | +| 56 | 57 | 24.61875916 | 33.23035812 | 0.527200282 | 0.943572938 | +| 57 | 58 | 24.61875916 | 34.56663132 | 0.409658968 | 0.898451686 | +| 58 | 59 | 24.61875916 | 34.43690109 | 0.454208463 | 0.904649734 | +| 59 | 60 | 24.61875916 | 35.0402565 | 0.409306735 | 0.902388573 | +| 60 | 61 | 24.61875916 | 34.91940308 | 0.443635762 | 0.911728501 | +| 61 | 62 | 24.61875916 | 38.25325394 | 0.42568323 | 0.965163887 | +| 62 | 63 | 24.61875916 | 32.07671356 | 0.727443576 | 0.94306612 | +| 63 | 64 | 24.61875916 | 31.72690964 | 0.671929657 | 0.902075231 | +| 64 | 65 | 24.61875916 | 33.47768402 | 0.533677042 | 0.922906399 | +| 65 | 66 | 24.61875916 | 42.14694977 | 0.266868591 | 0.991976082 | +| 66 | 67 | 24.61875916 | 30.81770706 | 0.84768647 | 0.957114518 | +| 67 | 68 | 24.61875916 | 32.24455261 | 0.623004258 | 0.97843051 | +| 68 | Average | 
24.61875916 | 34.05390495 | 0.555872787 | 0.935704286 | + +### 310推理 + +- 在 Ascend 310 处理器环境运行 + + ```python + #通过 sh 命令启动推理 + sh run_infer_310.sh [model_path] [data_path] [noise_image_path] [sigma] [channel] [device_id] + #上述命令将完成推理所需的全部工作。执行完成后,将产生 preprocess.log、infer.log、psnr.log 三个日志文件。 + #如果您需要单独执行各部分代码,可以参照 run_infer_310.sh 内的流程分别进行编译、图片预处理、推理和 PSNR 计算,请注意核对各部分所需参数! + ``` + +## 模型描述 + +### 性能 + +#### 评估性能 + +BRDNet on “waterloo5050step40colorimage” + +| Parameters | BRDNet | +| -------------------------- | ------------------------------------------------------------------------------ | +| Resource | Ascend 910 ;CPU 2.60GHz,192cores; Memory, 755G | +| uploaded Date | 5/20/2021 (month/day/year) | +| MindSpore Version | master | +| Dataset | waterloo5050step40colorimage | +| Training Parameters | epoch=50, batch_size=32, lr=0.001 | +| Optimizer | Adam | +| Loss Function | MSELoss(reduction='sum') | +| outputs | denoised images | +| Loss | 80.839773 | +| Speed | 8p about 7000FPS to 7400FPS | +| Total time | 8p about 2h 14min | +| Checkpoint for Fine tuning | 8p: 13.68MB , 1p: 19.76MB (.ckpt file) | +| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/brdnet | + +## 随机情况说明 + +train.py中设置了随机种子。 + +## ModelZoo主页 + +请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。 diff --git a/model_zoo/official/cv/brdnet/ascend310_infer/CMakeLists.txt b/model_zoo/official/cv/brdnet/ascend310_infer/CMakeLists.txt new file mode 100644 index 00000000000..ee3c8544734 --- /dev/null +++ b/model_zoo/official/cv/brdnet/ascend310_infer/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) 
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
diff --git a/model_zoo/official/cv/brdnet/ascend310_infer/build.sh b/model_zoo/official/cv/brdnet/ascend310_infer/build.sh
new file mode 100644
index 00000000000..285514e19f2
--- /dev/null
+++ b/model_zoo/official/cv/brdnet/ascend310_infer/build.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Build the Ascend 310 inference executable into ./out (fresh build each run).
if [ -d out ]; then
    rm -rf out
fi

mkdir out
cd out || exit

if [ -f "Makefile" ]; then
    make clean
fi

cmake .. \
    -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
diff --git a/model_zoo/official/cv/brdnet/ascend310_infer/inc/utils.h b/model_zoo/official/cv/brdnet/ascend310_infer/inc/utils.h
new file mode 100644
index 00000000000..abeb8fcbf11
--- /dev/null
+++ b/model_zoo/official/cv/brdnet/ascend310_infer/inc/utils.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_

// NOTE(review): the five include targets below were stripped to bare
// "#include" lines in the mangled diff; reconstructed from the declarations
// that follow (DIR -> <dirent.h>, std::string_view -> <string_view>, ...)
// -- confirm against the original patch.
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <string_view>
#include "include/api/types.h"

// Helpers shared by the Ascend 310 inference driver (src/main.cc).
// NOTE(review): template arguments of the vector types were also stripped in
// the mangled diff; std::string / mindspore::MSTensor are reconstructed from
// the matching definitions in src/utils.cc.
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif
diff --git a/model_zoo/official/cv/brdnet/ascend310_infer/src/main.cc b/model_zoo/official/cv/brdnet/ascend310_infer/src/main.cc
new file mode 100644
index 00000000000..ec9a80a9fda
--- /dev/null
+++ b/model_zoo/official/cv/brdnet/ascend310_infer/src/main.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "include/minddata/dataset/include/vision_ascend.h" +#include "include/minddata/dataset/include/execute.h" +#include "include/minddata/dataset/include/transforms.h" +#include "include/minddata/dataset/include/constants.h" +#include "include/minddata/dataset/include/vision.h" +#include "inc/utils.h" + +using mindspore::Serialization; +using mindspore::Model; +using mindspore::Context; +using mindspore::Status; +using mindspore::ModelType; +using mindspore::Graph; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::MSTensor; +using mindspore::DataType; + +using mindspore::dataset::Execute; +using mindspore::dataset::TensorTransform; +using mindspore::dataset::vision::Decode; +using mindspore::dataset::vision::Resize; +using mindspore::dataset::vision::HWC2CHW; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::transforms::TypeCast; + +DEFINE_string(model_path, "", "model path"); +DEFINE_string(dataset_path, "", "dataset path"); +DEFINE_int32(input_width, 500, "input width"); +DEFINE_int32(input_height, 500, "inputheight"); +DEFINE_int32(device_id, 0, "device id"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_model_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto context = std::make_shared(); + auto ascend310_info = std::make_shared(); + ascend310_info->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310_info); + + Graph graph; + Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph); + if (ret != kSuccess) { + std::cout << "Load model failed." 
<< std::endl; + return 1; + } + + Model model; + ret = model.Build(GraphCell(graph), context); + if (ret != kSuccess) { + std::cout << "ERROR: Build failed." << std::endl; + return 1; + } + + std::vector modelInputs = model.GetInputs(); + + auto all_files = GetAllFiles(FLAGS_dataset_path); + if (all_files.empty()) { + std::cout << "ERROR: no input data." << std::endl; + return 1; + } + + std::map costTime_map; + size_t size = all_files.size(); + // Define transform + + for (size_t i = 0; i < size; ++i) { + struct timeval start; + struct timeval end; + double startTime_ms; + double endTime_ms; + + std::vector inputs; + std::vector outputs; + + std::cout << "Start predict input files:" << all_files[i] << std::endl; + + auto image = ReadFileToTensor(all_files[i]); + + inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(), + image.Data().get(), image.DataSize()); + + gettimeofday(&start, NULL); + ret = model.Predict(inputs, &outputs); + gettimeofday(&end, NULL); + if (ret != kSuccess) { + std::cout << "Predict " << all_files[i] << " failed." 
<< std::endl; + return 1; + } + startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + costTime_map.insert(std::pair(startTime_ms, endTime_ms)); + WriteResult(all_files[i], outputs); + } + + double average = 0.0; + int infer_cnt = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = 0.0; + diff = iter->second - iter->first; + average += diff; + infer_cnt++; + } + + average = average / infer_cnt; + std::stringstream timeCost; + timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << infer_cnt << std::endl; + std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl; + std::string file_name = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream file_stream(file_name.c_str(), std::ios::trunc); + file_stream << timeCost.str(); + file_stream.close(); + costTime_map.clear(); + + return 0; +} diff --git a/model_zoo/official/cv/brdnet/ascend310_infer/src/utils.cc b/model_zoo/official/cv/brdnet/ascend310_infer/src/utils.cc new file mode 100644 index 00000000000..4557e1c3455 --- /dev/null +++ b/model_zoo/official/cv/brdnet/ascend310_infer/src/utils.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "inc/utils.h" + +#include +#include +#include + +using mindspore::MSTensor; +using mindspore::DataType; + +std::vector GetAllFiles(std::string_view dirName) { + struct dirent *filename; + DIR *dir = OpenDir(dirName); + if (dir == nullptr) { + return {}; + } + std::vector res; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." || filename->d_type != DT_REG) { + continue; + } + res.emplace_back(std::string(dirName) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + for (auto &f : res) { + std::cout << "image file: " << f << std::endl; + } + return res; +} + +int WriteResult(const std::string& imageFile, const std::vector &outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin"); + std::string outFileName = homePath + "/" + fileName; + FILE *outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string &file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, 
{static_cast(size)}, nullptr, size); + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + + +DIR *OpenDir(std::string_view dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! " << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" << std::endl; + return nullptr; + } + DIR *dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(std::string_view path) { + char realPathMem[PATH_MAX] = {0}; + char *realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} diff --git a/model_zoo/official/cv/brdnet/cal_psnr.py b/model_zoo/official/cv/brdnet/cal_psnr.py new file mode 100644 index 00000000000..5344d7ad210 --- /dev/null +++ b/model_zoo/official/cv/brdnet/cal_psnr.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Compute PSNR between clean test images and 310-inference outputs (.bin)."""
import argparse
import os
import time
import glob
import math
import numpy as np

## Params
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', default='./Test/Kodak24/', type=str,
                    help='directory of test dataset (must end with "/")')
parser.add_argument('--output_path', default='', type=str,
                    help='path of the predict files that generated by the model')
parser.add_argument('--image_width', default=500, type=int, help='image_width')
parser.add_argument('--image_height', default=500, type=int, help='image_height')
parser.add_argument('--channel', default=3, type=int,
                    help='image channel, 3 for color, 1 for gray')

# fix: parse_known_args tolerates extra arguments injected by launch wrappers
# (e.g. modelarts) instead of aborting the whole run on them
args = parser.parse_known_args()[0]

def calculate_psnr(image1, image2):
    """Return the PSNR (in dB) of image2 w.r.t. image1 (pixel range [0, 1]).

    Returns math.inf for identical images; the unguarded formula would raise
    ZeroDivisionError on rmse == 0.
    """
    image1 = np.float64(image1)
    image2 = np.float64(image2)
    diff = (image1 - image2).flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))
    if rmse == 0:  # identical inputs: avoid log10(1/0)
        return math.inf
    return 20 * math.log10(1.0 / rmse)

def cal_psnr(output_path, image_path):
    """Compare every clean image under image_path with the model output
    <name>_noise_0.bin under output_path and print per-image / average PSNR.
    """
    # imported lazily so calculate_psnr stays usable without PIL installed
    import PIL.Image as Image

    file_list = glob.glob(image_path + '*')  # image_path must end by '/'
    if not file_list:  # fix: avoid ZeroDivisionError in the average below
        print("No image found in", image_path)
        return
    psnr = []  # PSNR of each denoised output

    start_time = time.time()
    for file in file_list:
        filename = file.split('/')[-1].split('.')[0]  # get the name of image file
        # read the clean reference image, scaled to [0, 1], HWC
        if args.channel == 3:
            img_clean = np.array(Image.open(file), dtype='float32') / 255.0
        else:
            img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)

        # model output is raw float32 CHW, written by the 310 inference app
        result_file = os.path.join(output_path, filename + "_noise_0.bin")
        y_predict = np.fromfile(result_file, dtype=np.float32)
        y_predict = y_predict.reshape(args.channel, args.image_height, args.image_width)
        img_out = y_predict.transpose((1, 2, 0))  # CHW -> HWC
        img_out = np.clip(img_out, 0, 1)
        psnr_denoised = calculate_psnr(img_clean, img_out)
        psnr.append(psnr_denoised)
        print(filename, ": psnr_denoised: ", " ", psnr_denoised)

    psnr_avg = sum(psnr) / len(psnr)
    print("Average PSNR:", psnr_avg)
    print("Testing finished....")
    time_used = time.time() - start_time
    print("Time cost:" + str(time_used) + " seconds!")

if __name__ == '__main__':
    cal_psnr(args.output_path, args.image_path)
# ============================================================================
"""Evaluate a trained BRDNet checkpoint: add seeded Gaussian noise to a clean
test set, denoise it with the network, and report PSNR/SSIM before/after."""
import datetime
import argparse
import os
import time
import glob
import pandas as pd
import numpy as np
import PIL.Image as Image

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint, load_param_into_net

from src.logger import get_logger
from src.models import BRDNet

## Params
parser = argparse.ArgumentParser()

parser.add_argument('--test_dir', default='./Test/Kodak24/'
                    , type=str, help='directory of test dataset')
parser.add_argument('--sigma', default=15, type=int, help='noise level')
parser.add_argument('--channel', default=3, type=int
                    , help='image channel, 3 for color, 1 for gray')
parser.add_argument('--pretrain_path', default=None, type=str, help='path of pre-trained model')
parser.add_argument('--ckpt_name', default=None, type=str, help='ckpt_name')
parser.add_argument('--use_modelarts', type=int, default=0
                    , help='1 for True, 0 for False;when set True, we should load dataset from obs with moxing')
parser.add_argument('--train_url', type=str, default='train_url/'
                    , help='needed by modelarts, but we donot use it because the name is ambiguous')
parser.add_argument('--data_url', type=str, default='data_url/'
                    , help='needed by modelarts, but we donot use it because the name is ambiguous')
parser.add_argument('--output_path', type=str, default='./output/'
                    , help='output_path,when use_modelarts is set True, it will be cache/output/')
parser.add_argument('--outer_path', type=str, default='s3://output/'
                    , help='obs path,to store e.g ckpt files ')

parser.add_argument('--device_target', type=str, default='Ascend',
                    help='device where the code will be implemented. (Default: Ascend)')

set_seed(1)

args = parser.parse_args()
# per-run output directory tagged with noise level and launch timestamp
save_dir = os.path.join(args.output_path, 'sigma_' + str(args.sigma) + \
           '_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

if not args.use_modelarts and not os.path.exists(save_dir):
    os.makedirs(save_dir)

def test(model_path):
    """Denoise every image under args.test_dir with the checkpoint at
    model_path. Logs per-image and average PSNR/SSIM (noisy vs denoised),
    saves the noisy/denoised images and a metrics.csv (non-modelarts runs).
    """
    args.logger.info('Start to test on {}'.format(args.test_dir))
    out_dir = os.path.join(save_dir, args.test_dir.split('/')[-2]) # args.test_dir must end by '/'
    if not args.use_modelarts and not os.path.exists(out_dir):
        os.makedirs(out_dir)

    model = BRDNet(args.channel)
    args.logger.info('load test weights from '+str(model_path))
    load_param_into_net(model, load_checkpoint(model_path))
    name = []
    psnr = [] #after denoise
    ssim = [] #after denoise
    psnr_b = [] #before denoise
    ssim_b = [] #before denoise

    if args.use_modelarts:
        # NOTE: mox is made global by the `import moxing as mox` in __main__ below
        args.logger.info("copying test dataset from obs to cache....")
        mox.file.copy_parallel(args.test_dir, 'cache/test')
        args.logger.info("copying test dataset finished....")
        args.test_dir = 'cache/test/'

    file_list = glob.glob(args.test_dir+'*') # args.test_dir must end by '/'
    model.set_train(False)

    cast = P.Cast()
    transpose = P.Transpose()
    expand_dims = P.ExpandDims()
    compare_psnr = nn.PSNR()
    compare_ssim = nn.SSIM()

    args.logger.info("start testing....")
    start_time = time.time()
    for file in file_list:
        suffix = file.split('.')[-1]
        # read image (scaled to [0, 1], HWC)
        if args.channel == 3:
            img_clean = np.array(Image.open(file), dtype='float32') / 255.0
        else:
            img_clean = np.expand_dims(np.array(Image.open(file).convert('L'), dtype='float32') / 255.0, axis=2)

        np.random.seed(0) #obtain the same random data when it is in the test phase
        img_test = img_clean + np.random.normal(0, args.sigma/255.0, img_clean.shape)

        img_clean = Tensor(img_clean, mindspore.float32) #HWC
        img_test = Tensor(img_test, mindspore.float32) #HWC

        # predict
        img_clean = expand_dims(transpose(img_clean, (2, 0, 1)), 0)#NCHW
        img_test = expand_dims(transpose(img_test, (2, 0, 1)), 0)#NCHW

        y_predict = model(img_test) #NCHW

        # calculate numeric metrics

        img_out = C.clip_by_value(y_predict, 0, 1)

        psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test), compare_psnr(img_clean, img_out)
        ssim_noise, ssim_denoised = compare_ssim(img_clean, img_test), compare_ssim(img_clean, img_out)


        psnr.append(psnr_denoised.asnumpy()[0])
        ssim.append(ssim_denoised.asnumpy()[0])
        psnr_b.append(psnr_noise.asnumpy()[0])
        ssim_b.append(ssim_noise.asnumpy()[0])

        # save images
        filename = file.split('/')[-1].split('.')[0] # get the name of image file
        name.append(filename)

        if not args.use_modelarts:
            # inner the operation 'Image.save', it will first check the file \
            # existence of same name, that is not allowed on modelarts
            img_test = cast(img_test*255, mindspore.uint8).asnumpy()
            img_test = img_test.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
            img_test = Image.fromarray(img_test)
            img_test.save(os.path.join(out_dir, filename+'_sigma'+'{}_psnr{:.2f}.'\
                .format(args.sigma, psnr_noise.asnumpy()[0])+str(suffix)))
            img_out = cast(img_out*255, mindspore.uint8).asnumpy()
            img_out = img_out.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
            img_out = Image.fromarray(img_out)
            img_out.save(os.path.join(out_dir, filename+'_psnr{:.2f}.'.format(psnr_denoised.asnumpy()[0])+str(suffix)))


    # averages appended as a final 'Average' row for the csv below
    psnr_avg = sum(psnr)/len(psnr)
    ssim_avg = sum(ssim)/len(ssim)
    psnr_avg_b = sum(psnr_b)/len(psnr_b)
    ssim_avg_b = sum(ssim_b)/len(ssim_b)
    name.append('Average')
    psnr.append(psnr_avg)
    ssim.append(ssim_avg)
    psnr_b.append(psnr_avg_b)
    ssim_b.append(ssim_avg_b)
    args.logger.info('Before denoise: Average PSNR_b = {0:.2f}, \
SSIM_b = {1:.2f};After denoise: Average PSNR = {2:.2f}, SSIM = {3:.2f}'\
.format(psnr_avg_b, ssim_avg_b, psnr_avg, ssim_avg))
    args.logger.info("testing finished....")
    time_used = time.time() - start_time
    args.logger.info("time cost:"+str(time_used)+" seconds!")
    if not args.use_modelarts:
        pd.DataFrame({'name': np.array(name), 'psnr_b': np.array(psnr_b), \
                      'psnr': np.array(psnr), 'ssim_b': np.array(ssim_b), \
                      'ssim': np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True)

if __name__ == '__main__':

    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, device_id=device_id, save_graphs=False)

    args.logger = get_logger(save_dir, "BRDNet", 0)
    args.logger.save_args(args)

    if args.use_modelarts:
        import moxing as mox
        args.logger.info("copying test weights from obs to cache....")
        mox.file.copy_parallel(args.pretrain_path, 'cache/weight')
        args.logger.info("copying test weights finished....")
        args.pretrain_path = 'cache/weight/'

    test(os.path.join(args.pretrain_path, args.ckpt_name))

    if args.use_modelarts:
        args.logger.info("copying files from cache to obs....")
        mox.file.copy_parallel(save_dir, args.outer_path)
        args.logger.info("copying finished....")
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export

from src.models import BRDNet


## Params
parser = argparse.ArgumentParser()

parser.add_argument('--batch_size', default=32, type=int, help='batch size')
parser.add_argument('--channel', default=3, type=int
                    , help='image channel, 3 for color, 1 for gray')
parser.add_argument("--image_height", type=int, default=50, help="Image height.")
parser.add_argument("--image_width", type=int, default=50, help="Image width.")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="brdnet", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument('--device_target', type=str, default='Ascend'
                    , help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--device_id", type=int, default=0, help="Device id")

args_opt = parser.parse_args()

# graph-mode context must be configured before the network is constructed
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)

if __name__ == '__main__':

    # build the network and restore the trained weights
    net = BRDNet(args_opt.channel)

    param_dict = load_checkpoint(args_opt.ckpt_file)
    load_param_into_net(net, param_dict)

    # trace the graph with a zero dummy input of the export shape (NCHW)
    input_arr = Tensor(np.zeros([args_opt.batch_size, args_opt.channel, \
                args_opt.image_height, args_opt.image_width]), ms.float32)
    export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)
# ============================================================================
"""Turn clean test images into noisy network inputs (.bin, float32 NCHW)."""
import argparse
import os
import glob
import numpy as np
import PIL.Image as Image

parser = argparse.ArgumentParser()
parser.add_argument('--out_dir', type=str, required=True,
                    help='directory to store the image with noise')
parser.add_argument('--image_path', type=str, required=True,
                    help='directory of image to add noise')
parser.add_argument('--channel', type=int, default=3
                    , help='image channel, 3 for color, 1 for gray')
parser.add_argument('--sigma', type=int, default=15, help='level of noise')
args = parser.parse_args()

def add_noise(out_dir, image_path, channel, sigma):
    """Add seeded Gaussian noise to every image under image_path and dump
    each result to out_dir as <name>_noise.bin (float32, NCHW)."""
    image_files = glob.glob(image_path + '*')  # image_path must end by '/'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    for img_file in image_files:
        print("Adding noise to: ", img_file)
        # load the clean image as HWC float32 scaled to [0, 1]
        if channel == 3:
            clean = np.array(Image.open(img_file), dtype='float32') / 255.0
        else:
            clean = np.expand_dims(np.array(Image.open(img_file).convert('L'), dtype='float32') / 255.0, axis=2)

        # fixed seed: the same noise is reproduced on every test run
        np.random.seed(0)
        gauss = np.random.normal(0, sigma / 255.0, clean.shape).astype(np.float32)  # HWC
        noisy = np.expand_dims((clean + gauss).transpose((2, 0, 1)), 0)  # HWC -> NCHW

        stem = img_file.split('/')[-1].split('.')[0]  # image file name without extension
        noisy.tofile(os.path.join(out_dir, stem + '_noise.bin'))

if __name__ == "__main__":
    add_noise(args.out_dir, args.image_path, args.channel, args.sigma)
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Launch 8-device distributed training of BRDNet (one background python
# process per device, each in its own device<N> working directory).

if [ $# != 8 ]; then
    echo "Usage: sh run_distribute_train.sh [train_code_path] [train_data]" \
         "[batch_size] [sigma] [channel] [epoch] [lr] [rank_table_file_path]"
    exit 1
fi

# Resolve a possibly-relative directory path to an absolute one ending with "/".
get_real_path() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)/"
  fi
}

# Resolve a possibly-relative file path to an absolute one (no trailing "/").
get_real_path_name() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

train_code_path=$(get_real_path $1)
echo $train_code_path

if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a dictionary."
exit 1
fi

train_data=$(get_real_path $2)
echo $train_data

if [ ! -d $train_data ]
then
    echo "error: train_data=$train_data is not a dictionary."
exit 1
fi

rank_table_file_path=$(get_real_path_name $8)
echo "rank_table_file_path: "$rank_table_file_path

ulimit -c unlimited
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=$rank_table_file_path  # HCCL rank table for the 8 devices
export RANK_SIZE=8
export RANK_START_ID=0

# one detached training process per device; logs go to device<N>/log.txt
for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
    export DEVICE_ID=$((i + RANK_START_ID))
    echo 'start rank='${i}', device id='${DEVICE_ID}'...'
    if [ -d ${train_code_path}device${DEVICE_ID} ]; then
      rm -rf ${train_code_path}device${DEVICE_ID}
    fi
    mkdir ${train_code_path}device${DEVICE_ID}
    cd ${train_code_path}device${DEVICE_ID} || exit
    nohup python ${train_code_path}train.py --is_distributed=1 \
    --train_data=${train_data} \
    --batch_size=$3 \
    --sigma=$4 \
    --channel=$5 \
    --epoch=$6 \
    --lr=$7 > ${train_code_path}device${DEVICE_ID}/log.txt 2>&1 &
done
# test_dir must be an existing directory of clean test images.
if [ ! -d $test_dir ]
then
    echo "error: test_dir=$test_dir is not a dictionary."
exit 1
fi

# pretrain_path must be an existing directory holding the checkpoint file.
if [ ! -d $pretrain_path ]
then
    echo "error: pretrain_path=$pretrain_path is not a dictionary."
exit 1
fi

# Run evaluation; the checkpoint is passed as directory ($5) plus file
# name ($6) separately for modelarts compatibility.
python ${train_code_path}eval.py \
--test_dir=$test_dir \
--sigma=$3 \
--channel=$4 \
--pretrain_path=$pretrain_path \
--ckpt_name=$6
# ============================================================================

# End-to-end Ascend 310 inference for BRDNet:
#   1. compile the C++ inference app,
#   2. preprocess clean test images into noisy .bin inputs,
#   3. run on-device inference,
#   4. compute PSNR of the denoised outputs.

if [ $# != 6 ]; then
    echo "Usage: sh run_infer_310.sh [model_path] [data_path]" \
         "[noise_image_path] [sigma] [channel] [device_id]"
exit 1
fi

# Resolve a possibly-relative file path to an absolute one (no trailing "/").
get_real_path_name() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

# Resolve a possibly-relative directory path to an absolute one ending with "/".
get_real_path() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)/"
  fi
}

model=$(get_real_path_name $1)
data_path=$(get_real_path $2)
noise_image_path=$(get_real_path $3)
sigma=$4
channel=$5
device_id=$6

echo "model path: "$model
echo "dataset path: "$data_path
echo "noise image path: "$noise_image_path
echo "sigma: "$sigma
echo "channel: "$channel
echo "device id: "$device_id

# Ascend toolchain environment: new ascend-toolkit layout first, legacy layout otherwise.
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

function compile_app()
{
    cd ../ascend310_infer/ || exit
    if [ -f "Makefile" ]; then
        make clean
    fi
    sh build.sh &> build.log
}

function preprocess()
{
    cd ../ || exit
    # fix: plain echo does not interpret "\n" in bash; -e enables the escape
    echo -e "\nstart preprocess"
    echo "waitting for preprocess finish..."
    python3.7 preprocess.py --out_dir=$noise_image_path --image_path=$data_path --channel=$channel --sigma=$sigma > preprocess.log 2>&1
    echo "preprocess finished! you can see the log in preprocess.log!"
}

function infer()
{
    cd ./scripts || exit
    # start from clean output directories
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    if [ -d time_Result ]; then
        rm -rf ./time_Result
    fi
    mkdir result_Files
    mkdir time_Result
    echo -e "\nstart infer..."
    echo "waitting for infer finish..."
    ../ascend310_infer/out/main --model_path=$model --dataset_path=$noise_image_path --device_id=$device_id > infer.log 2>&1
    echo "infer finished! you can see the log in infer.log!"
}

function cal_psnr()
{
    echo -e "\nstart calculate PSNR..."
    echo "waitting for calculate finish..."
    python3.7 ../cal_psnr.py --image_path=$data_path --output_path=./result_Files --channel=$channel > psnr.log 2>&1
    echo -e "calculate finished! you can see the log in psnr.log\n"
}

compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi

preprocess
if [ $? -ne 0 ]; then
    echo "execute preprocess failed"
    exit 1
fi

infer
if [ $? -ne 0 ]; then
    echo " execute inference failed"
    exit 1
fi
cal_psnr
if [ $? -ne 0 ]; then
    echo "calculate psnr failed"
    exit 1
fi
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Launch single-device training of BRDNet in the background.

if [ $# != 7 ]; then
    echo "Usage: sh run_train.sh [train_code_path] [train_data]" \
         "[batch_size] [sigma] [channel] [epoch] [lr]"
    exit 1
fi

# Resolve a possibly-relative directory path to an absolute one ending with "/".
get_real_path() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)/"
  fi
}

train_code_path=$(get_real_path $1)
echo "train_code_path: "$train_code_path

train_data=$(get_real_path $2)
echo "train_data: "$train_data

if [ ! -d $train_code_path ]
then
    echo "error: train_code_path=$train_code_path is not a dictionary."
exit 1
fi

if [ ! -d $train_data ]
then
    echo "error: train_data=$train_data is not a dictionary."
exit 1
fi

# run training detached; all output goes to log.txt
nohup python ${train_code_path}train.py --is_distributed=0 \
--train_data=${train_data} \
--batch_size=$3 \
--sigma=$4 \
--channel=$5 \
--epoch=$6 \
--lr=$7 > log.txt 2>&1 &
echo 'Train task has been started successfully!'
echo 'Please check the log at log.txt'
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training dataset for BRDNet: clean .bmp patches plus synthetic Gaussian noise."""
import glob
import numpy as np
import PIL.Image as Image
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler

class BRDNetDataset:
    """ BRDNetDataset.
    Args:
        data_path: path of images, must end by '/'
        sigma: noise level (std of the Gaussian noise, on the 0-255 scale)
        channel: 3 for color, 1 for gray
    """
    def __init__(self, data_path, sigma, channel):
        images = []
        file_dictory = glob.glob(data_path+'*.bmp') #notice the data format: only .bmp files are collected
        for file in file_dictory:
            images.append(file)
        self.images = images    # list of .bmp file paths
        self.sigma = sigma
        self.channel = channel
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            get_batch_x: image with noise (float32 HWC, clean + Gaussian noise)
            get_batch_y: image without noise (float32 HWC in [0, 1])
        """
        if self.channel == 3:
            get_batch_y = np.array(Image.open(self.images[index]), dtype='uint8')
        else:
            get_batch_y = np.expand_dims(np.array(Image.open(self.images[index]).convert('L'), dtype='uint8'), axis=2)

        get_batch_y = get_batch_y.astype('float32')/255.0
        # noise is re-drawn on every access (no fixed seed), so each epoch
        # sees a different corruption of the same clean patch
        noise = np.random.normal(0, self.sigma/255.0, get_batch_y.shape).astype('float32') # noise
        get_batch_x = get_batch_y + noise # input image = clean image + noise
        return get_batch_x, get_batch_y

    def __len__(self):
        return len(self.images)

def create_BRDNetDataset(data_path, sigma, channel, batch_size, device_num, rank, shuffle):
    """Build the batched (and, for device_num > 1, sharded) pipeline.
    Returns (data_set, dataset_len) where dataset_len is the total number of
    samples before sharding/batching."""
    dataset = BRDNetDataset(data_path, sigma, channel)
    dataset_len = len(dataset)
    # shard sample indices across devices; `rank` selects this device's slice
    distributed_sampler = DistributedSampler(dataset_len, device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()
    data_set = ds.GeneratorDataset(dataset, column_names=["image", "label"], \
               shuffle=shuffle, sampler=distributed_sampler)
    # network consumes CHW tensors, so convert both columns
    data_set = data_set.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
    data_set = data_set.map(input_columns=["label"], operations=hwc_to_chw, num_parallel_workers=8)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set, dataset_len
# ============================================================================
"""BRDNetDataset distributed sampler."""
import math
import numpy as np


class DistributedSampler:
    """Shards dataset indices across replicas: rank r gets every
    num_replicas-th index, padded by wrap-around so all ranks get the same
    number of samples."""

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # every rank receives exactly ceil(size / replicas) samples
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # Shuffle deterministically from the epoch counter so every rank
        # computes the same permutation for a given epoch.
        if self.shuffle:
            rng = np.random.RandomState(seed=self.epoch)
            order = rng.permutation(self.dataset_size).tolist()
            self.epoch += 1
        else:
            order = list(range(self.dataset_size))

        # pad by wrapping around so the index list divides evenly
        pad = self.total_size - len(order)
        order = order + order[:pad]
        assert len(order) == self.total_size

        # strided subsample: rank r takes positions r, r+R, r+2R, ...
        shard = order[self.rank:self.total_size:self.num_replicas]
        assert len(shard) == self.num_samples

        return iter(shard)

    def __len__(self):
        return self.num_samples
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime


class LOGGER(logging.Logger):
    """
    Rank-aware logger for distributed training.

    One process per node (ranks that are multiples of 8) additionally
    echoes records to stdout; every rank can attach its own log file via
    ``setup_logging_file``.

    Args:
        logger_name: String. Logger name.
        rank: Integer. Rank id.
    """

    # shared record layout for both console and file handlers
    _FMT = '%(asctime)s:%(levelname)s:%(message)s'

    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.rank = rank
        # only the first rank of each 8-device node writes to stdout
        if rank % 8 == 0:
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setLevel(logging.INFO)
            stream_handler.setFormatter(logging.Formatter(self._FMT))
            self.addHandler(stream_handler)

    def setup_logging_file(self, log_dir, rank=0):
        """Attach a per-rank, timestamped file handler under *log_dir*."""
        self.rank = rank
        os.makedirs(log_dir, exist_ok=True)
        stamp = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
        self.log_fn = os.path.join(log_dir, '{}_rank_{}.log'.format(stamp, rank))
        file_handler = logging.FileHandler(self.log_fn)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(self._FMT))
        self.addHandler(file_handler)

    def info(self, msg, *args, **kwargs):
        """INFO-level log; mirrors the stdlib signature."""
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        """Dump every attribute of an argparse namespace, one per line."""
        self.info('Args:')
        for name, value in vars(args).items():
            self.info('--> %s: %s', name, value)
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        """Log *msg* wrapped in an asterisk banner (rank 0 only)."""
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            banner = ['']
            banner += ['*' * 70] * line_width
            banner += ['*' * line_width] * 2
            banner.append('*' * line_width + ' ' * 8 + msg)
            banner += ['*' * line_width] * 2
            banner += ['*' * 70] * line_width
            self.info('\n'.join(banner) + '\n', *args, **kwargs)


def get_logger(path, logger_name, rank):
    """Build a LOGGER that logs to a file in *path* (and stdout per-node)."""
    logger = LOGGER(logger_name, rank)
    logger.setup_logging_file(path, rank)
    return logger
# ============================================================================
"""BRDNet network definition plus loss and training wrappers."""
import mindspore.nn as nn
import mindspore.ops as ops

from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import functional as F

class BRDNet(nn.Cell):
    """
    Two-branch image denoising network (BRDNet).

    The upper branch is a plain 17-layer CNN (Conv+BN+ReLU, 15 x
    Conv+BN+ReLU, Conv); the lower branch interleaves dilated (rate 2)
    convolutions to enlarge the receptive field. Each branch predicts
    the noise and subtracts it from the input; the two residual estimates
    are concatenated along channels and fused by a final 3x3 convolution.

    NOTE(review): the paper uses Batch Renormalization (BRN); this
    implementation substitutes nn.BatchNorm2d — confirm this is intended.

    Args:
        channel (int): image channels, 3 for color and 1 for gray.

    Inputs:
        inpt (Tensor): noisy image batch, NCHW layout.

    Outputs:
        Tensor: denoised batch with the same shape as ``inpt``.
    """
    def __init__(self, channel):
        super(BRDNet, self).__init__()

        self.Conv2d_1 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_1 = nn.BatchNorm2d(64, eps=1e-3)
        self.layer1 = self.make_layer1(15)
        self.Conv2d_2 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.Conv2d_3 = nn.Conv2d(channel, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_2 = nn.BatchNorm2d(64, eps=1e-3)
        self.layer2 = self.make_layer2(7)
        self.Conv2d_4 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_3 = nn.BatchNorm2d(64, eps=1e-3)
        self.layer3 = self.make_layer2(6)
        self.Conv2d_5 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.BRN_4 = nn.BatchNorm2d(64, eps=1e-3)
        self.Conv2d_6 = nn.Conv2d(64, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)
        self.Conv2d_7 = nn.Conv2d(channel*2, channel, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True)

        self.relu = nn.ReLU()
        self.sub = ops.Sub()
        self.concat = ops.Concat(axis=1)  # NCHW: concatenate along channel axis

    def make_layer1(self, nums):
        """Stack ``nums`` Conv(3x3)+BN+ReLU units for the upper branch."""
        # Fix: validate with an exception instead of assert, so the check
        # survives `python -O`.
        if nums <= 0:
            raise ValueError('nums must be a positive integer, got {}'.format(nums))
        layers = []
        for _ in range(nums):
            layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad_mode='same', has_bias=True))
            layers.append(nn.BatchNorm2d(64, eps=1e-3))
            layers.append(nn.ReLU())
        return nn.SequentialCell(layers)

    def make_layer2(self, nums):
        """Stack ``nums`` dilated Conv(3x3, rate 2)+ReLU units for the lower branch."""
        if nums <= 0:
            raise ValueError('nums must be a positive integer, got {}'.format(nums))
        layers = []
        for _ in range(nums):
            layers.append(nn.Conv2d(64, 64, kernel_size=(3, 3), \
                          stride=(1, 1), dilation=(2, 2), pad_mode='same', has_bias=True))
            layers.append(nn.ReLU())
        return nn.SequentialCell(layers)

    def construct(self, inpt):
        """Return the denoised image for the NCHW batch ``inpt``."""
        # ---- upper branch: Conv+BN+ReLU, 15 x (Conv+BN+ReLU), Conv ----
        x = self.Conv2d_1(inpt)
        x = self.BRN_1(x)
        x = self.relu(x)
        # 15 layers, Conv+BN+relu
        x = self.layer1(x)

        # last layer, Conv
        x = self.Conv2d_2(x)  # output channels: 1 for gray, 3 for color
        x = self.sub(inpt, x)  # residual learning: input - predicted noise

        # ---- lower branch: dilated convolutions for a larger receptive field ----
        y = self.Conv2d_3(inpt)
        y = self.BRN_2(y)
        y = self.relu(y)

        # first block of dilated Conv+ReLU
        y = self.layer2(y)

        y = self.Conv2d_4(y)
        y = self.BRN_3(y)
        y = self.relu(y)

        # second block of dilated Conv+ReLU
        y = self.layer3(y)

        y = self.Conv2d_5(y)
        y = self.BRN_4(y)
        y = self.relu(y)

        y = self.Conv2d_6(y)  # output channels: 1 for gray, 3 for color
        y = self.sub(inpt, y)  # residual learning: input - predicted noise

        # fuse both denoised estimates and apply residual learning once more
        o = self.concat((x, y))
        z = self.Conv2d_7(o)
        z = self.sub(inpt, z)

        return z

class BRDWithLossCell(nn.Cell):
    """Wrap BRDNet with a summed MSE reconstruction loss.

    'sum' reduction is used instead of 'mean' to keep the loss magnitude
    from becoming too small on 50x50 patches.
    """
    def __init__(self, network):
        super(BRDWithLossCell, self).__init__()
        self.network = network
        self.loss = nn.MSELoss(reduction='sum')  # 'sum' avoids a vanishingly small loss

    def construct(self, images, targets):
        output = self.network(images)
        return self.loss(output, targets)

class TrainingWrapper(nn.Cell):
    """One-step training cell: forward, backward and optimizer update.

    In data-parallel / hybrid-parallel mode a DistributedGradReducer
    all-reduces the gradients before the update.

    Args:
        network (nn.Cell): loss-producing network (e.g. BRDWithLossCell).
        optimizer (nn.Optimizer): optimizer holding the trainable params.
        sens (float): gradient scaling factor fed to the backward pass.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        # Fill the output-shaped sensitivity tensor with self.sens
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        # depend() ties the optimizer update into the graph, returning the loss
        return F.depend(loss, self.optimizer(grads))
# ============================================================================
"""Train BRDNet on the Waterloo 50x50 color-patch dataset."""
import os
import datetime
import argparse

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore.train.callback import TimeMonitor, LossMonitor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_rank, get_group_size

from src.logger import get_logger
from src.dataset import create_BRDNetDataset
from src.models import BRDNet, BRDWithLossCell, TrainingWrapper


## Params
parser = argparse.ArgumentParser()

parser.add_argument('--batch_size', default=32, type=int, help='batch size')
parser.add_argument('--train_data', default='../dataset/waterloo5050step40colorimage/'
                    , type=str, help='path of train data')
parser.add_argument('--sigma', default=75, type=int, help='noise level')
parser.add_argument('--channel', default=3, type=int
                    , help='image channel, 3 for color, 1 for gray')
parser.add_argument('--epoch', default=50, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--save_every', default=1, type=int, help='save model at every x epoches')
parser.add_argument('--pretrain', default=None, type=str, help='path of pre-trained model')
parser.add_argument('--use_modelarts', type=int, default=0
                    , help='1 for True, 0 for False; when set True, we should load dataset from obs with moxing')
parser.add_argument('--train_url', type=str, default='train_url/'
                    , help='needed by modelarts, but we donot use it because the name is ambiguous')
parser.add_argument('--data_url', type=str, default='data_url/'
                    , help='needed by modelarts, but we donot use it because the name is ambiguous')
parser.add_argument('--output_path', type=str, default='./output/'
                    , help='output_path,when use_modelarts is set True, it will be cache/output/')
parser.add_argument('--outer_path', type=str, default='s3://output/'
                    , help='obs path,to store e.g ckpt files ')

parser.add_argument('--device_target', type=str, default='Ascend'
                    , help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
parser.add_argument('--ckpt_save_max', type=int, default=5
                    , help='Maximum number of checkpoint files can be saved. Default: 5.')

set_seed(1)
args = parser.parse_args()
# Per-run output directory: sigma level + launch timestamp.
save_dir = os.path.join(args.output_path, 'sigma_' + str(args.sigma) \
           + '_' + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

if not args.use_modelarts and not os.path.exists(save_dir):
    os.makedirs(save_dir)

def get_lr(steps_per_epoch, max_epoch, init_lr):
    """Build a per-step lr schedule: constant lr, divided by 10 every 30 epochs.

    Args:
        steps_per_epoch (int): optimizer steps in one epoch.
        max_epoch (int): total number of epochs.
        init_lr (float): learning rate for the first 30-epoch stage.

    Returns:
        list[float]: one lr value per training step.
    """
    lr_each_step = []
    while max_epoch > 0:
        stage = min(30, max_epoch)
        lr_each_step.extend([init_lr] * (steps_per_epoch * stage))
        max_epoch -= stage
        init_lr /= 10
    return lr_each_step


device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                    device_target=args.device_target, save_graphs=False)

def train():
    """Set up (possibly distributed) context, data, model and run training."""
    if args.is_distributed:
        assert args.device_target == "Ascend"
        init()
        context.set_context(device_id=device_id)
        args.rank = get_rank()
        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL)
    else:
        if args.device_target == "Ascend":
            context.set_context(device_id=device_id)

    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    args.logger = get_logger(save_dir, "BRDNet", args.rank)
    args.logger.save_args(args)

    if args.use_modelarts:
        import moxing as mox
        args.logger.info("copying train data from obs to cache....")
        mox.file.copy_parallel(args.train_data, 'cache/dataset')  # the args.train_data must end by '/'
        args.logger.info("copying traindata finished....")
        args.train_data = 'cache/dataset/'  # the args.train_data must end by '/'

    dataset, batch_num = create_BRDNetDataset(args.train_data, args.sigma, \
                         args.channel, args.batch_size, args.group_size, args.rank, shuffle=True)

    # batch_num is the total sample count; convert it to optimizer steps.
    args.steps_per_epoch = int(batch_num / args.batch_size / args.group_size)

    model = BRDNet(args.channel)

    if args.pretrain:
        if args.use_modelarts:
            import moxing as mox
            args.logger.info("copying pretrain model from obs to cache....")
            mox.file.copy_parallel(args.pretrain, 'cache/pretrain')
            args.logger.info("copying pretrain model finished....")
            args.pretrain = 'cache/pretrain/' + args.pretrain.split('/')[-1]

        args.logger.info('loading pre-trained model {} into network'.format(args.pretrain))
        load_param_into_net(model, load_checkpoint(args.pretrain))
        args.logger.info('loaded pre-trained model {} into network'.format(args.pretrain))

    model = BRDWithLossCell(model)
    model.set_train()

    lr_list = get_lr(args.steps_per_epoch, args.epoch, args.lr)
    optimizer = nn.Adam(params=model.trainable_params(), learning_rate=Tensor(lr_list, mindspore.float32))
    model = TrainingWrapper(model, optimizer)

    model = Model(model)

    # define callbacks
    if args.rank == 0:
        # Fix: TimeMonitor expects steps per epoch, not the raw sample count.
        time_cb = TimeMonitor(data_size=args.steps_per_epoch)
        loss_cb = LossMonitor(per_print_times=10)
        callbacks = [time_cb, loss_cb]
    else:
        # Fix: use an empty list instead of None — when is_save_on_master is 0
        # every rank has rank_save_ckpt_flag set, and appending the checkpoint
        # callback to None would raise AttributeError.
        callbacks = []
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every,
                                       keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='channel_'+str(args.channel)+'_sigma_'+str(args.sigma)+'_rank_'+str(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.epoch, dataset, callbacks=callbacks, dataset_sink_mode=True)

    args.logger.info("training finished....")
    if args.use_modelarts:
        args.logger.info("copying files from cache to obs....")
        mox.file.copy_parallel(save_dir, args.outer_path)
        args.logger.info("copying finished....")

if __name__ == '__main__':
    train()
    args.logger.info('All task finished!')