fix readme
This commit is contained in:
parent
c5b896b328
commit
6a7e4961c4
|
@ -0,0 +1,243 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [概述](#概述)
|
||||
- [论文](#论文)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本结构与说明](#脚本结构与说明)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [用法](#用法)
|
||||
- [Ascend处理器环境运行](#Ascend处理器环境运行)
|
||||
- [结果](#结果)
|
||||
- [评估过程](#评估过程)
|
||||
- [用法](#用法-1)
|
||||
- [Ascend处理器环境运行](#Ascend处理器环境运行-1)
|
||||
- [结果](#结果-1)
|
||||
- [推理过程](#推理过程)
|
||||
- [导出MindIR](#导出MindIR)
|
||||
- [在Acsend310执行推理](#在Acsend310执行推理)
|
||||
- [结果](#结果)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# Learning To See In The Dark
|
||||
|
||||
## 概述
|
||||
|
||||
Leraning To See In The dark 是在2018年提出的,基于全卷积神经网络(FCN)的一个网络模型,用于图像处理。网络的主题结构为U-net,将低曝光度的图像输入网络,经过处理后输出得到对应的高曝光度图像,实现了图像的增亮和去噪处理。
|
||||
|
||||
## 论文
|
||||
|
||||
[1] Chen C, Chen Q, Xu J, et al. Learning to See in the Dark[C]// 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. IEEE, 2018.
|
||||
|
||||
# 模型架构
|
||||
|
||||
网络主体为Unet,将raw data输入后pack成四个channel,去除blacklevel并乘以ratio后输入网络主体(Unet),输出为RGB图像。
|
||||
|
||||
# 数据集
|
||||
|
||||
- 数据集地址:
|
||||
|
||||
- [下载Sony数据集](https://storage.googleapis.com/isl-datasets/SID/Sony.zip)
|
||||
|
||||
- 数据集包含了室内和室外图像。室外图像通常是在月光或街道照明条件下拍摄。在室外场景下,相机的亮度一般在0.2 lux 和5 lux 之间。室内图像通常更暗。在室内场景中的相机亮度一般在0.03 lux 和0.3 lux 之间。输入图像的曝光时间设置为1/30和1/10秒。相应的参考图像 (真实图像) 的曝光时间通常会延长100到300倍:即10至30秒。
|
||||
|
||||
- 数据集分类(文件名开头):
|
||||
|
||||
- 0: 训练数据集
|
||||
- 1:推理数据集
|
||||
- 2:验证数据集
|
||||
|
||||
- 数据集目录结构:
|
||||
|
||||
```text
|
||||
└─dataset
|
||||
├─long # label
|
||||
└─short # input
|
||||
```
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件
|
||||
- 准备Ascend处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
# 运行评估示例
|
||||
用法:sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本结构与说明
|
||||
|
||||
```text
|
||||
└──LearningToSeeInTheDark
|
||||
├── README.md
|
||||
├── scripts
|
||||
├── run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||
├── run_eval.sh # 启动Ascend评估
|
||||
└── run_standalone_train.sh # 启动Ascend单机训练(单卡)
|
||||
├── src
|
||||
├── myutils.py # TrainOneStepWithLossScale & GradClip
|
||||
└── unet_parts.py # 网络主题结构的部分定义
|
||||
├── eval.py # 评估网络
|
||||
└── train.py # 训练网络
|
||||
```
|
||||
|
||||
# 脚本参数
|
||||
|
||||
- 配置超参数。
|
||||
|
||||
```Python
|
||||
"batch_size":8, # 输入张量的批次大小
|
||||
"epoch_size":3000, # 训练周期大小
|
||||
"save_checkpoint":True, # 是否保存检查点
|
||||
"save_checkpoint_epochs":100, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
|
||||
"keep_checkpoint_max":100, # 只保存最后一个keep_checkpoint_max检查点
|
||||
"save_checkpoint_path":"./", # 检查点相对于执行路径的保存路径
|
||||
"warmup_epochs":500, # 热身周期数
|
||||
"lr":3e-4 # 基础学习率
|
||||
"lr_end":1e-6, # 最终学习率
|
||||
```
|
||||
|
||||
# 训练过程
|
||||
|
||||
## 用法
|
||||
|
||||
### Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 分布式训练
|
||||
用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
# 单机训练
|
||||
用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||
|
||||
```
|
||||
|
||||
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
||||
|
||||
具体操作,参见[hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)中的说明。
|
||||
|
||||
训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。
|
||||
|
||||
## 结果
|
||||
|
||||
```text
|
||||
# 分布式训练结果(8P)
|
||||
epoch: 1 step: 4, loss is 0.22979942
|
||||
epoch: 2 step: 4, loss is 0.25466543
|
||||
epoch: 3 step: 4, loss is 0.2032796
|
||||
epoch: 4 step: 4, loss is 0.18603589
|
||||
epoch: 5 step: 4, loss is 0.19579497
|
||||
...
|
||||
```
|
||||
|
||||
# 评估过程
|
||||
|
||||
## 用法
|
||||
|
||||
### Ascend处理器环境运行
|
||||
|
||||
```Shell
|
||||
# 评估
|
||||
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
```Shell
|
||||
# 评估示例
|
||||
sh run_eval.sh /data/dataset/ImageNet/imagenet_original Resnet152-140_5004.ckpt
|
||||
```
|
||||
|
||||
## 结果
|
||||
|
||||
评估结果保存在示例路径中,文件夹名为“eval”。您可在此路径下找到经过网络处理的输出图像。
|
||||
|
||||
# 推理过程
|
||||
|
||||
## [导出MindIR](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
参数ckpt_file为必填项,
|
||||
`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中选择。
|
||||
|
||||
## 在Ascend310执行推理
|
||||
|
||||
在执行推理前,mindir文件必须通过`export.py`脚本导出。以下展示了使用minir模型执行推理的示例。
|
||||
|
||||
```shell
|
||||
# Ascend310 inference
|
||||
bash export_MINDIR.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
|
||||
```
|
||||
|
||||
- `MINDIR_PATH` mindir文件路径
|
||||
- `DATASET_PATH` 推理数据集路径
|
||||
- `DEVICE_ID` 可选,默认值为0。
|
||||
|
||||
## 结果
|
||||
|
||||
推理结果保存在脚本执行的当前路径,你可以在当前文件夹查看输出图片。
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
| 参数 | Ascend 910 |
|
||||
|---|---|
|
||||
| 模型版本 | Learning To See In The Dark |
|
||||
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G |
|
||||
| 上传日期 |2021-06-21 ; |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | SID |
|
||||
| 训练参数 | epoch=2500, steps per epoch=35, batch_size = 8 |
|
||||
| 优化器 | Adam |
|
||||
| 损失函数 | L1loss |
|
||||
| 输出 | 高亮度图像 |
|
||||
| 损失 | 0.030 |
|
||||
| 速度|606.12毫秒/步(8卡) |
|
||||
| 总时长 | 132分钟 |
|
||||
| 参数(M) | 60.19 |
|
||||
| 微调检查点 | 462M(.ckpt文件) |
|
||||
| 脚本 | [链接](https://gitee.com/alreadyhad/mindspore/tree/master/model_zoo/research/cv/LearningToSeeInTheDark) |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
unet_parts.py train_sony.py中各自设置了随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" export MINDIR"""
|
||||
import argparse as arg
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, export, load_checkpoint
|
||||
import mindspore.nn as nn
|
||||
from src.unet_parts import DoubleConv, Down, Up, OutConv
|
||||
|
||||
|
||||
class UNet(nn.Cell):
|
||||
""" Unet """
|
||||
|
||||
def __init__(self, n_channels, n_classes):
|
||||
super(UNet, self).__init__()
|
||||
self.n_channels = n_channels
|
||||
self.n_classes = n_classes
|
||||
self.inc = DoubleConv(n_channels, 32)
|
||||
self.down1 = Down(32, 64)
|
||||
self.down2 = Down(64, 128)
|
||||
self.down3 = Down(128, 256)
|
||||
self.down4 = Down(256, 512)
|
||||
self.up1 = Up(512, 256)
|
||||
self.up2 = Up(256, 128)
|
||||
self.up3 = Up(128, 64)
|
||||
self.up4 = Up(64, 32)
|
||||
self.outc = OutConv(32, n_classes)
|
||||
|
||||
def construct(self, x):
|
||||
"""Unet construct"""
|
||||
|
||||
x1 = self.inc(x)
|
||||
x2 = self.down1(x1)
|
||||
x3 = self.down2(x2)
|
||||
x4 = self.down3(x3)
|
||||
x5 = self.down4(x4)
|
||||
x = self.up1(x5, x4)
|
||||
x = self.up2(x, x3)
|
||||
x = self.up3(x, x2)
|
||||
x = self.up4(x, x1)
|
||||
logits = self.outc(x)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = arg.ArgumentParser(description='SID export')
|
||||
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
|
||||
help='device where the code will be implemented')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id')
|
||||
parser.add_argument('--file_format', type=str, choices=['AIR', 'MINDIR'], default='MINDIR',
|
||||
help='file format')
|
||||
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == 'Ascend':
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
ckpt_dir = args.checkpoint_path
|
||||
net = UNet(4, 12)
|
||||
load_checkpoint(ckpt_dir, net=net)
|
||||
net.set_train(False)
|
||||
|
||||
input_data = Tensor(np.zeros([1, 4, 1424, 2128]), ms.float32)
|
||||
export(net, input_data, file_name='sid', file_format=args.file_format)
|
|
@ -0,0 +1,90 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH PRETRAINED_CKPT_PATH](optional)"
|
||||
echo "For example: bash run_distribute_train.sh hccl_8p_01234567_127.0.0.1.json /path/dataset"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PATH3=$(get_real_path $3)
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: DATA_PATH=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $# == 3 ] && [ ! -f $PATH3 ]
|
||||
then
|
||||
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
|
||||
|
||||
DATA_PATH=$2
|
||||
export DATA_PATH=${DATA_PATH}
|
||||
|
||||
for((i=0;i<${RANK_SIZE};i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
cp ../*.py ./device$i
|
||||
cp *.sh ./device$i
|
||||
cp -r ../src ./device$i
|
||||
cd ./device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((i))
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
python train_sony.py --run_distribute=True --data_url=$PATH2 &> train.log &
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python train_sony.py --run_distribute=True --data_url=$PATH2 --pre_trained=$PATH3 &> train.log &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,64 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_eval.sh DATA_PATH CHECKPOINT_PATH "
|
||||
echo "For example: bash run.sh /path/dataset Resnet152-140_5004.ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=6
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python test_sony.py --data_url=$PATH1 --checkpoint_path=$PATH2 &> eval.log &
|
||||
cd ..
|
|
@ -0,0 +1,77 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_standalone_train.sh DATA_PATH PRETRAINED_CKPT_PATH(optional)"
|
||||
echo "For example: bash run_standalone_train.sh /path/dataset"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
PATH2=$(get_real_path $2)
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $# == 2 ] && [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=6
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 1 ]
|
||||
then
|
||||
python train_sony.py --run_distribute=False --data_url=$PATH1 &> train.log &
|
||||
fi
|
||||
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
python train_sony.py --run_distribute=False --data_url=$PATH1 --pre_trained=$PATH2 &> train.log &
|
||||
fi
|
||||
cd ..
|
||||
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"batch_size": 8,
|
||||
"total_epochs": 3000,
|
||||
"warmup_epochs": 500,
|
||||
"train_output_dir": "./",
|
||||
})
|
|
@ -0,0 +1,235 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train one step with loss scale"""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to compute loss.
|
||||
Args:
|
||||
backbone (Cell): The target network to wrap.
|
||||
loss_fn (Cell): The loss function used to compute loss.
|
||||
"""
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def construct(self, x, label):
|
||||
""" construct of loss cell """
|
||||
logits = self._backbone(x)
|
||||
return self._loss_fn(logits, label)
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""
|
||||
Get the backbone network.
|
||||
Returns:
|
||||
Cell, return backbone network.
|
||||
"""
|
||||
return self._backbone
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 5
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self, grads, clip_type, clip_value):
|
||||
"""
|
||||
Construct gradient clip network.
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in [0, 1]:
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
""" grad scale """
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
|
||||
class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of GNMT network training.
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
Args:
|
||||
network: Cell. The training network. Note that loss function should have
|
||||
been added.
|
||||
optimizer: Optimizer. Optimizer for updating the weights.
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(GNMTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.all_reduce = P.AllReduce()
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
self.add_flags(has_effect=True)
|
||||
self.loss_scalar = P.ScalarSummary()
|
||||
|
||||
def construct(self, inputs, labels, sens=None):
|
||||
"""
|
||||
network processing
|
||||
overflow testing
|
||||
"""
|
||||
weights = self.weights
|
||||
loss = self.network(inputs, labels)
|
||||
|
||||
# Alloc status.
|
||||
init = self.alloc_status()
|
||||
|
||||
# Clear overflow buffer.
|
||||
self.clear_before_grad(init)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(inputs, labels, self.cast(scaling_sens, mstype.float32))
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
|
||||
if self.reducer_flag:
|
||||
# Apply grad reducer on grads.
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
|
||||
if self.is_distributed:
|
||||
# Sum overflow flag over devices.
|
||||
flag_reduce = self.all_reduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
|
||||
overflow = cond
|
||||
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
self.loss_scalar("loss", loss)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
|
||||
return F.depend(ret, succ)
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Unet Components"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as F
|
||||
from mindspore.ops import Maximum
|
||||
from mindspore.ops import DepthToSpace as dts
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.common.initializer import XavierUniform
|
||||
import mindspore as ms
|
||||
ms.set_seed(1212)
|
||||
|
||||
|
||||
class LRelu(nn.Cell):
|
||||
""" activation function """
|
||||
def __init__(self):
|
||||
super(LRelu, self).__init__()
|
||||
self.max = Maximum()
|
||||
|
||||
def construct(self, x):
|
||||
""" construct of lrelu activation """
|
||||
return self.max(x * 0.2, x)
|
||||
|
||||
|
||||
class DoubleConv(nn.Cell):
|
||||
"""conv2d for two times with lrelu activation"""
|
||||
def __init__(self, in_channels, out_channels, mid_channels=None):
|
||||
super(DoubleConv, self).__init__()
|
||||
if not mid_channels:
|
||||
mid_channels = out_channels
|
||||
self.kernel_init = XavierUniform()
|
||||
self.double_conv = nn.SequentialCell(
|
||||
[nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, pad_mode="same",
|
||||
weight_init=self.kernel_init), LRelu(),
|
||||
nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, pad_mode="same",
|
||||
weight_init=self.kernel_init), LRelu()])
|
||||
|
||||
def construct(self, x):
|
||||
""" construct of double conv2d """
|
||||
return self.double_conv(x)
|
||||
|
||||
|
||||
class Down(nn.Cell):
|
||||
"""Downscaling with maxpool then double conv"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(Down, self).__init__()
|
||||
self.maxpool_conv = nn.SequentialCell(
|
||||
[nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same"),
|
||||
DoubleConv(in_channels, out_channels)]
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
""" construct of down cell """
|
||||
return self.maxpool_conv(x)
|
||||
|
||||
|
||||
class Up(nn.Cell):
|
||||
"""Upscaling then double conv"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(Up, self).__init__()
|
||||
self.concat = F.Concat(axis=1)
|
||||
self.kernel_init = TruncatedNormal(0.02)
|
||||
self.conv = DoubleConv(in_channels, out_channels)
|
||||
self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2,
|
||||
pad_mode='same', weight_init=self.kernel_init)
|
||||
|
||||
def construct(self, x1, x2):
|
||||
""" construct of up cell """
|
||||
x1 = self.up(x1)
|
||||
x = self.concat((x1, x2))
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class OutConv(nn.Cell):
|
||||
"""trans data into RGB channels"""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(OutConv, self).__init__()
|
||||
self.kernel_init = XavierUniform()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, pad_mode='same', weight_init=self.kernel_init)
|
||||
self.DtS = dts(block_size=2)
|
||||
|
||||
def construct(self, x):
|
||||
""" construct of last conv """
|
||||
x = self.conv(x)
|
||||
x = self.DtS(x)
|
||||
return x
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test"""
|
||||
from __future__ import division
|
||||
import argparse as arg
|
||||
import os
|
||||
import glob
|
||||
from PIL import Image
|
||||
import h5py
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor, dtype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.unet_parts import DoubleConv, Down, Up, OutConv
|
||||
|
||||
|
||||
class UNet(nn.Cell):
|
||||
""" Unet """
|
||||
|
||||
def __init__(self, n_channels, n_classes):
|
||||
super(UNet, self).__init__()
|
||||
self.n_channels = n_channels
|
||||
self.n_classes = n_classes
|
||||
self.inc = DoubleConv(n_channels, 32)
|
||||
self.down1 = Down(32, 64)
|
||||
self.down2 = Down(64, 128)
|
||||
self.down3 = Down(128, 256)
|
||||
self.down4 = Down(256, 512)
|
||||
self.up1 = Up(512, 256)
|
||||
self.up2 = Up(256, 128)
|
||||
self.up3 = Up(128, 64)
|
||||
self.up4 = Up(64, 32)
|
||||
self.outc = OutConv(32, n_classes)
|
||||
|
||||
def construct(self, x):
|
||||
"""Unet construct"""
|
||||
|
||||
x1 = self.inc(x)
|
||||
x2 = self.down1(x1)
|
||||
x3 = self.down2(x2)
|
||||
x4 = self.down3(x3)
|
||||
x5 = self.down4(x4)
|
||||
x = self.up1(x5, x4)
|
||||
x = self.up2(x, x3)
|
||||
x = self.up3(x, x2)
|
||||
x = self.up4(x, x1)
|
||||
logits = self.outc(x)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def pack_raw(raw):
|
||||
""" pack sony raw data into 4 channels """
|
||||
|
||||
im = np.maximum(raw - 512, 0) / (16383 - 512) # subtract the black level
|
||||
im = np.expand_dims(im, axis=2)
|
||||
img_shape = im.shape
|
||||
H = img_shape[0]
|
||||
W = img_shape[1]
|
||||
|
||||
out = np.concatenate((im[0:H:2, 0:W:2, :],
|
||||
im[0:H:2, 1:W:2, :],
|
||||
im[1:H:2, 1:W:2, :],
|
||||
im[1:H:2, 0:W:2, :]), axis=2)
|
||||
return out
|
||||
|
||||
|
||||
def get_test_data(input_dir1, gt_dir1, test_ids1):
|
||||
""" trans input raw data into arrays then pack into a list """
|
||||
|
||||
final_test_inputs = []
|
||||
for test_id in test_ids1:
|
||||
in_files = glob.glob(input_dir1 + '%05d_00*.hdf5' % test_id)
|
||||
|
||||
gt_files = glob.glob(gt_dir1 + '%05d_00*.hdf5' % test_id)
|
||||
gt_path = gt_files[0]
|
||||
gt_fn = os.path.basename(gt_path)
|
||||
gt_exposure = float(gt_fn[9: -6])
|
||||
|
||||
for in_path in in_files:
|
||||
|
||||
in_fn = os.path.basename(in_path)
|
||||
in_exposure = float(in_fn[9: -6])
|
||||
ratio = min(gt_exposure / in_exposure, 300.0)
|
||||
ima = h5py.File(in_path, 'r')
|
||||
in_rawed = ima.get('in')[:]
|
||||
input_image = np.expand_dims(pack_raw(in_rawed), axis=0) * ratio
|
||||
input_image = np.minimum(input_image, 1.0)
|
||||
input_image = input_image.transpose([0, 3, 1, 2])
|
||||
input_image = np.float32(input_image)
|
||||
final_test_inputs.append(input_image)
|
||||
return final_test_inputs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = arg.ArgumentParser(description='Mindspore SID Eval')
|
||||
parser.add_argument('--device_target', default='Ascend',
|
||||
help='device where the code will be implemented')
|
||||
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
|
||||
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
local_data_path = args.data_url
|
||||
input_dir = os.path.join(local_data_path, 'short/')
|
||||
gt_dir = os.path.join(local_data_path, 'long/')
|
||||
test_fns = glob.glob(gt_dir + '1*.hdf5')
|
||||
test_ids = [int(os.path.basename(test_fn)[0:5]) for test_fn in test_fns]
|
||||
ckpt_dir = args.checkpoint_path
|
||||
param_dict = load_checkpoint(ckpt_dir)
|
||||
net = UNet(4, 12)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
in_ims = get_test_data(input_dir, gt_dir, test_ids)
|
||||
i = 0
|
||||
for in_im in in_ims:
|
||||
output = net(Tensor(in_im, dtype.float32))
|
||||
output = output.asnumpy()
|
||||
output = np.minimum(np.maximum(output, 0), 1)
|
||||
output = np.trunc(output[0] * 255)
|
||||
output = output.astype(np.int8)
|
||||
output = output.transpose([1, 2, 0])
|
||||
image_out = Image.fromarray(output, 'RGB')
|
||||
image_out.save('output_%d.png' % i)
|
||||
i += 1
|
|
@ -0,0 +1,223 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train"""
|
||||
from __future__ import division
|
||||
import os
|
||||
import glob
|
||||
import argparse as arg
|
||||
import ast
|
||||
import h5py
|
||||
import numpy as np
|
||||
from mindspore import context, Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.nn.loss import L1Loss
|
||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr as pc_lr
|
||||
from mindspore.nn.dynamic_lr import warmup_lr
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.communication.management import init
|
||||
import mindspore.nn as nn
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.unet_parts import DoubleConv, Down, Up, OutConv
|
||||
from src.myutils import GNMTTrainOneStepWithLossScaleCell, WithLossCell
|
||||
from src.configs import config
|
||||
|
||||
|
||||
class UNet(nn.Cell):
|
||||
""" Unet """
|
||||
def __init__(self, n_channels, n_classes):
|
||||
super(UNet, self).__init__()
|
||||
self.n_channels = n_channels
|
||||
self.n_classes = n_classes
|
||||
self.inc = DoubleConv(n_channels, 32)
|
||||
self.down1 = Down(32, 64)
|
||||
self.down2 = Down(64, 128)
|
||||
self.down3 = Down(128, 256)
|
||||
self.down4 = Down(256, 512)
|
||||
self.up1 = Up(512, 256)
|
||||
self.up2 = Up(256, 128)
|
||||
self.up3 = Up(128, 64)
|
||||
self.up4 = Up(64, 32)
|
||||
self.outc = OutConv(32, n_classes)
|
||||
|
||||
def construct(self, x):
|
||||
""" Unet construct """
|
||||
x1 = self.inc(x)
|
||||
x2 = self.down1(x1)
|
||||
x3 = self.down2(x2)
|
||||
x4 = self.down3(x3)
|
||||
x5 = self.down4(x4)
|
||||
x = self.up1(x5, x4)
|
||||
x = self.up2(x, x3)
|
||||
x = self.up3(x, x2)
|
||||
x = self.up4(x, x1)
|
||||
logits = self.outc(x)
|
||||
return logits
|
||||
|
||||
|
||||
def pack_raw(raw):
|
||||
""" pack sony raw data into 4 channels """
|
||||
im = np.maximum(raw - 512, 0) / (16383 - 512) # subtract the black level
|
||||
|
||||
im = np.expand_dims(im, axis=2)
|
||||
img_shape = im.shape
|
||||
H = img_shape[0]
|
||||
W = img_shape[1]
|
||||
|
||||
out = np.concatenate((im[0:H:2, 0:W:2, :],
|
||||
im[0:H:2, 1:W:2, :],
|
||||
im[1:H:2, 1:W:2, :],
|
||||
im[1:H:2, 0:W:2, :]), axis=2)
|
||||
return out
|
||||
|
||||
|
||||
def get_dataset(input_dir1, gt_dir1, train_ids1, num_shards=None, shard_id=None, distribute=False):
|
||||
""" get mindspore dataset from raw data """
|
||||
input_final_data = []
|
||||
gt_final_data = []
|
||||
for train_id in train_ids1:
|
||||
in_files = glob.glob(input_dir1 + '%05d_00*.hdf5' % train_id)
|
||||
|
||||
gt_files = glob.glob(gt_dir1 + '%05d_00*.hdf5' % train_id)
|
||||
gt_path = gt_files[0]
|
||||
gt_fn = os.path.basename(gt_path)
|
||||
gt_exposure = float(gt_fn[9: -6])
|
||||
gt = h5py.File(gt_path, 'r')
|
||||
gt_rawed = gt.get('gt')[:]
|
||||
gt_image = np.expand_dims(np.float32(gt_rawed / 65535.0), axis=0)
|
||||
gt_image = gt_image.transpose([0, 3, 1, 2])
|
||||
|
||||
for in_path in in_files:
|
||||
gt_final_data.append(gt_image[0])
|
||||
|
||||
in_fn = os.path.basename(in_path)
|
||||
in_exposure = float(in_fn[9: -6])
|
||||
ratio = min(gt_exposure / in_exposure, 300)
|
||||
im = h5py.File(in_path, 'r')
|
||||
in_rawed = im.get('in')[:]
|
||||
input_image = np.expand_dims(pack_raw(in_rawed), axis=0) * ratio
|
||||
input_image = np.float32(input_image)
|
||||
input_image = input_image.transpose([0, 3, 1, 2])
|
||||
input_final_data.append(input_image[0])
|
||||
data = (input_final_data, gt_final_data)
|
||||
if distribute:
|
||||
datasets = ds.NumpySlicesDataset(data, ['input', 'label'], shuffle=False,
|
||||
num_shards=num_shards, shard_id=shard_id)
|
||||
else:
|
||||
datasets = ds.NumpySlicesDataset(data, ['input', 'label'], shuffle=False)
|
||||
return datasets
|
||||
|
||||
|
||||
def dynamic_lr(steps_per_epoch, warmup_epochss): # if warmup, plus warmup_epochs
|
||||
""" learning rate with warmup"""
|
||||
milestone = [(1200 + warmup_epochss) * steps_per_epoch,
|
||||
(1300 + warmup_epochss) * steps_per_epoch,
|
||||
(1700 + warmup_epochss) * steps_per_epoch,
|
||||
(2500 + warmup_epochss) * steps_per_epoch]
|
||||
learning_rates = [3e-4, 1e-5, 3e-6, 1e-6]
|
||||
lrs = pc_lr(milestone, learning_rates)
|
||||
return lrs
|
||||
|
||||
|
||||
def RandomCropAndFlip(image, label):
|
||||
""" random crop and flip """
|
||||
ps = 512
|
||||
# random crop
|
||||
h = image.shape[1]
|
||||
w = image.shape[2]
|
||||
xx = np.random.randint(0, h - ps)
|
||||
yy = np.random.randint(0, w - ps)
|
||||
image = image[:, xx:xx + ps, yy:yy + ps]
|
||||
label = label[:, xx * 2:xx * 2 + ps * 2, yy * 2:yy * 2 + ps * 2]
|
||||
# random flip
|
||||
if np.random.randint(2) == 1: # random flip
|
||||
image = np.flip(image, axis=1)
|
||||
label = np.flip(label, axis=1)
|
||||
if np.random.randint(2) == 1:
|
||||
image = np.flip(image, axis=2)
|
||||
label = np.flip(label, axis=2)
|
||||
if np.random.randint(2) == 1: # random transpose
|
||||
image = np.transpose(image, (0, 2, 1))
|
||||
label = np.transpose(label, (0, 2, 1))
|
||||
image = np.minimum(image, 1.0)
|
||||
|
||||
return image, label
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = arg.ArgumentParser(description='Mindspore SID Example')
|
||||
parser.add_argument('--device_target', default='Ascend',
|
||||
help='device where the code will be implemented')
|
||||
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
|
||||
parser.add_argument('--pre_trained', required=False, default=None, help='Ckpt file path')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, required=False, default=None,
|
||||
help='If run distributed')
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
||||
if args.run_distribute:
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
|
||||
local_data_path = args.data_url
|
||||
|
||||
input_dir = os.path.join(local_data_path, 'short/')
|
||||
gt_dir = os.path.join(local_data_path, 'long/')
|
||||
|
||||
train_fns = glob.glob(gt_dir + '0*.hdf5')
|
||||
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]
|
||||
|
||||
net = UNet(4, 12)
|
||||
net_loss = L1Loss()
|
||||
net = WithLossCell(net, net_loss)
|
||||
if args.run_distribute:
|
||||
dataset = get_dataset(input_dir, gt_dir, train_ids,
|
||||
num_shards=device_num, shard_id=device_id, distribute=True)
|
||||
else:
|
||||
dataset = get_dataset(input_dir, gt_dir, train_ids)
|
||||
transform_list = [RandomCropAndFlip]
|
||||
dataset = dataset.map(transform_list, input_columns=['input', 'label'], output_columns=['input', 'label'])
|
||||
dataset = dataset.shuffle(buffer_size=161)
|
||||
dataset = dataset.batch(batch_size=config.batch_size, drop_remainder=True)
|
||||
batches_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
lr_warm = warmup_lr(learning_rate=3e-4, total_step=config.warmup_epochs * batches_per_epoch,
|
||||
step_per_epoch=batches_per_epoch, warmup_epoch=config.warmup_epochs)
|
||||
lr = dynamic_lr(batches_per_epoch, config.warmup_epochs)
|
||||
lr = lr_warm + lr[config.warmup_epochs:]
|
||||
net_opt = nn.Adam(net.trainable_params(), lr)
|
||||
scale_manager = DynamicLossScaleManager()
|
||||
net = GNMTTrainOneStepWithLossScaleCell(net, net_opt, scale_manager.get_update_cell())
|
||||
|
||||
ckpt_dir = args.pre_trained
|
||||
if ckpt_dir is not None:
|
||||
param_dict = load_checkpoint(ckpt_dir)
|
||||
load_param_into_net(net, param_dict)
|
||||
model = Model(net)
|
||||
|
||||
loss_cb = LossMonitor()
|
||||
time_cb = TimeMonitor(data_size=4)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=100 * batches_per_epoch, keep_checkpoint_max=100)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='sony_trained_net', directory=config.train_output_dir, config=config_ck)
|
||||
callbacks_list = [ckpoint_cb, loss_cb, time_cb]
|
||||
model.train(epoch=config.total_epochs, train_dataset=dataset,
|
||||
callbacks=callbacks_list,
|
||||
dataset_sink_mode=True)
|
Loading…
Reference in New Issue