fix readme

despicablemme 2021-06-10 10:42:36 +08:00
parent c5b896b328
commit 6a7e4961c4
10 changed files with 1271 additions and 0 deletions

View File: README.md

@ -0,0 +1,243 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [Overview](#overview)
- [Paper](#paper)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script Structure](#script-structure)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Usage](#usage)
- [Running on Ascend](#running-on-ascend)
- [Results](#results)
- [Evaluation Process](#evaluation-process)
- [Usage](#usage-1)
- [Running on Ascend](#running-on-ascend-1)
- [Results](#results-1)
- [Inference Process](#inference-process)
- [Export MindIR](#export-mindir)
- [Inference on Ascend 310](#inference-on-ascend-310)
- [Results](#results-2)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# Learning To See In The Dark
## Overview
Learning to See in the Dark, proposed in 2018, is a fully convolutional network (FCN) model for image processing. Its backbone is a U-Net: a low-exposure image is fed into the network and the corresponding high-exposure image is produced at the output, so brightening and denoising happen in a single pass.
## Paper
[1] Chen C, Chen Q, Xu J, et al. Learning to See in the Dark[C]// 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. IEEE, 2018.
# Model Architecture
The backbone of the network is a U-Net. The input raw data is packed into four channels, the black level is subtracted, and the result is multiplied by the exposure ratio before being fed into the U-Net backbone, whose output is an RGB image.
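A minimal sketch of this preprocessing, following the `pack_raw` helper defined in `train_sony.py` and `test_sony.py` (the black level 512 and white level 16383 match the Sony raw data used here; `ratio` is the ground-truth/input exposure ratio, capped at 300 in the scripts):
```Python
import numpy as np

def pack_raw(raw, ratio):
    """Pack a Bayer raw frame (H, W) into 4 half-resolution channels, then amplify."""
    im = np.maximum(raw - 512, 0) / (16383 - 512)    # subtract the black level and normalize
    im = np.expand_dims(im, axis=2)
    h, w = im.shape[0], im.shape[1]
    out = np.concatenate((im[0:h:2, 0:w:2, :],       # the four Bayer phases
                          im[0:h:2, 1:w:2, :],
                          im[1:h:2, 1:w:2, :],
                          im[1:h:2, 0:w:2, :]), axis=2)
    return np.minimum(out * ratio, 1.0)              # scale to the target exposure, clip to [0, 1]
```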
# Dataset
- Dataset location:
    - [Download the Sony dataset](https://storage.googleapis.com/isl-datasets/SID/Sony.zip)
- The dataset contains indoor and outdoor images. Outdoor images were usually captured under moonlight or street lighting; camera-side illuminance is generally between 0.2 lux and 5 lux. Indoor images are usually darker, with illuminance between 0.03 lux and 0.3 lux. Exposure times of the input images are set to 1/30 s and 1/10 s; the corresponding reference (ground-truth) images use exposures 100 to 300 times longer, i.e. 10 to 30 seconds.
- Dataset splits, by file-name prefix (see the sketch after the directory tree below):
    - 0: training set
    - 1: inference set
    - 2: validation set
- Dataset directory structure:
```text
└─dataset
├─long # label
└─short # input
```
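The scripts select these subsets by matching the leading digit of the file name; a sketch of the pattern used in `train_sony.py` and `test_sony.py` (the dataset root below is a placeholder):
```Python
import glob
import os

gt_dir = '/path/dataset/long/'                     # placeholder dataset root
train_fns = glob.glob(gt_dir + '0*.hdf5')          # prefix 0: training set
test_fns = glob.glob(gt_dir + '1*.hdf5')           # prefix 1: inference set
train_ids = [int(os.path.basename(fn)[0:5]) for fn in train_fns]
```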
# Environment Requirements
- Hardware
    - Prepare an Ascend processor to set up the hardware environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start
After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
- Running on Ascend
```Shell
# distributed training
Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
# run the evaluation example
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
# Script Description
## Script Structure
```text
└──LearningToSeeInTheDark
  ├── README.md
  ├── scripts
    ├── run_distribute_train.sh    # launch Ascend distributed training (8 devices)
    ├── run_eval.sh                # launch Ascend evaluation
    └── run_standalone_train.sh    # launch Ascend standalone training (single device)
  ├── src
    ├── configs.py                 # parameter configuration
    ├── myutils.py                 # TrainOneStepWithLossScale & GradClip
    └── unet_parts.py              # building blocks of the network backbone
  ├── export.py                    # export a MindIR/AIR model
  ├── test_sony.py                 # evaluation network
  └── train_sony.py                # training network
```
## Script Parameters
- Configure the hyper-parameters.
```Python
"batch_size": 8,                 # batch size of the input tensor
"epoch_size": 3000,              # number of training epochs
"save_checkpoint": True,         # whether to save checkpoints
"save_checkpoint_epochs": 100,   # epoch interval between two checkpoints; by default the last checkpoint is saved after the final epoch
"keep_checkpoint_max": 100,      # keep at most the last keep_checkpoint_max checkpoints
"save_checkpoint_path": "./",    # checkpoint save path, relative to the execution path
"warmup_epochs": 500,            # number of warm-up epochs
"lr": 3e-4,                      # base learning rate
"lr_end": 1e-6,                  # final learning rate
```
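Note that only part of these keys live in `src/configs.py` (batch size, epoch counts and the output directory); the learning-rate schedule itself is built in `train_sony.py`. A minimal check of the shipped config:
```Python
from src.configs import config  # an EasyDict, see src/configs.py

print(config.batch_size, config.total_epochs, config.warmup_epochs)  # 8 3000 500
```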
# Training Process
## Usage
### Running on Ascend
```Shell
# distributed training
Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
```
Distributed training requires an HCCL configuration file in JSON format, created in advance.
For details, see the instructions in [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
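For illustration only, a single-server 8-device rank table can be written like this (a hedged sketch: every ID and IP below is a placeholder; a real table must use the device IPs reported by your Ascend environment, which is what hccn_tools queries for you):
```Python
import json

rank_table = {
    "version": "1.0",
    "server_count": "1",
    "server_list": [{
        "server_id": "127.0.0.1",                          # placeholder host address
        "device": [{"device_id": str(i),
                    "device_ip": "192.98.92.%d" % (i + 1),  # placeholder NIC IPs
                    "rank_id": str(i)} for i in range(8)],
        "host_nic_ip": "reserve",
    }],
    "status": "completed",
}
with open("hccl_8p_01234567_127.0.0.1.json", "w") as f:
    json.dump(rank_table, f, indent=4)
```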
Training results are saved in the example path, in folders named "train" (standalone) or "device0" to "device7" (distributed). You can find checkpoint files and results like the following in the logs (train.log) under that path.
## Results
```text
# distributed training results (8P)
epoch: 1 step: 4, loss is 0.22979942
epoch: 2 step: 4, loss is 0.25466543
epoch: 3 step: 4, loss is 0.2032796
epoch: 4 step: 4, loss is 0.18603589
epoch: 5 step: 4, loss is 0.19579497
...
```
# Evaluation Process
## Usage
### Running on Ascend
```Shell
# evaluation
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
```Shell
# evaluation example
sh run_eval.sh /path/dataset sony_trained_net-3000_35.ckpt
```
## Results
Evaluation results are saved in the example path, in a folder named "eval". There you can find the output images produced by the network.
# Inference Process
## [Export MindIR](#contents)
```shell
python export.py --checkpoint_path [CKPT_PATH] --file_format [FILE_FORMAT]
```
The `checkpoint_path` parameter is required. `file_format` must be chosen from ["AIR", "MINDIR"]; the exported file is named `sid`.
## Inference on Ascend 310
Before running inference, the MindIR file must be exported via the `export.py` script. The following shows an example of running inference with a MindIR model.
```shell
# Ascend 310 inference
bash export_MINDIR.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
```
- `MINDIR_PATH`: path of the MindIR file
- `DATASET_PATH`: path of the inference dataset
- `DEVICE_ID`: optional, defaults to 0.
## Results
Inference results are saved in the path where the script is executed; the output images can be found in the current folder.
# Model Description
## Performance
### Evaluation Performance
| Parameter | Ascend 910 |
|---|---|
| Model version | Learning To See In The Dark |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Upload date | 2021-06-21 |
| MindSpore version | 1.2.0 |
| Dataset | SID |
| Training parameters | epoch=2500, steps per epoch=35, batch_size=8 |
| Optimizer | Adam |
| Loss function | L1 loss |
| Output | high-brightness image |
| Loss | 0.030 |
| Speed | 606.12 ms/step (8 devices) |
| Total duration | 132 minutes |
| Parameters (M) | 60.19 |
| Checkpoint for fine-tuning | 462M (.ckpt file) |
| Script | [Link](https://gitee.com/alreadyhad/mindspore/tree/master/model_zoo/research/cv/LearningToSeeInTheDark) |
# Description of Random Situation
Random seeds are set in unet_parts.py and train_sony.py.
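For example, `src/unet_parts.py` seeds MindSpore's global generator at import time:
```Python
import mindspore as ms

ms.set_seed(1212)  # the seed set in src/unet_parts.py
```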
# ModelZoo Homepage
Please check out the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File: export.py

@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" export MINDIR"""
import argparse as arg
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, export, load_checkpoint
import mindspore.nn as nn
from src.unet_parts import DoubleConv, Down, Up, OutConv
class UNet(nn.Cell):
""" Unet """
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.inc = DoubleConv(n_channels, 32)
self.down1 = Down(32, 64)
self.down2 = Down(64, 128)
self.down3 = Down(128, 256)
self.down4 = Down(256, 512)
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
self.up4 = Up(64, 32)
self.outc = OutConv(32, n_classes)
def construct(self, x):
"""Unet construct"""
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
if __name__ == '__main__':
parser = arg.ArgumentParser(description='SID export')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device where the code will be implemented')
parser.add_argument('--device_id', type=int, default=0, help='device id')
parser.add_argument('--file_format', type=str, choices=['AIR', 'MINDIR'], default='MINDIR',
help='file format')
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == 'Ascend':
context.set_context(device_id=args.device_id)
ckpt_dir = args.checkpoint_path
net = UNet(4, 12)
load_checkpoint(ckpt_dir, net=net)
net.set_train(False)
    input_data = Tensor(np.zeros([1, 4, 1424, 2128]), ms.float32)  # NCHW dummy input: 4 packed Bayer channels at half the 2848x4256 sensor resolution
export(net, input_data, file_name='sid', file_format=args.file_format)

View File: scripts/run_distribute_train.sh

@ -0,0 +1,90 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH PRETRAINED_CKPT_PATH](optional)"
echo "For example: bash run_distribute_train.sh hccl_8p_01234567_127.0.0.1.json /path/dataset"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ $# == 3 ]
then
PATH3=$(get_real_path $3)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATA_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
DATA_PATH=$2
export DATA_PATH=${DATA_PATH}
for((i=0;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir device$i
cp ../*.py ./device$i
cp *.sh ./device$i
cp -r ../src ./device$i
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$((i))
echo "start training for device $i"
env > env$i.log
if [ $# == 2 ]
then
python train_sony.py --run_distribute=True --data_url=$PATH2 &> train.log &
fi
if [ $# == 3 ]
then
python train_sony.py --run_distribute=True --data_url=$PATH2 --pre_trained=$PATH3 &> train.log &
fi
cd ../
done

View File: scripts/run_eval.sh

@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DATA_PATH CHECKPOINT_PATH "
echo "For example: bash run.sh /path/dataset Resnet152-140_5004.ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=6
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
env > env.log
echo "start evaluation for device $DEVICE_ID"
python test_sony.py --data_url=$PATH1 --checkpoint_path=$PATH2 &> eval.log &
cd ..

View File: scripts/run_standalone_train.sh

@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DATA_PATH PRETRAINED_CKPT_PATH(optional)"
echo "For example: bash run_standalone_train.sh /path/dataset"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ $# == 2 ]
then
PATH2=$(get_real_path $2)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 2 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=6
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
python train_sony.py --run_distribute=False --data_url=$PATH1 &> train.log &
fi
if [ $# == 2 ]
then
python train_sony.py --run_distribute=False --data_url=$PATH1 --pre_trained=$PATH2 &> train.log &
fi
cd ..

View File: src/configs.py

@ -0,0 +1,25 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"batch_size": 8,
"total_epochs": 3000,
"warmup_epochs": 500,
"train_output_dir": "./",
})

View File: src/myutils.py

@ -0,0 +1,235 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train one step with loss scale"""
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, label):
""" construct of loss cell """
logits = self._backbone(x)
return self._loss_fn(logits, label)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 5
clip_grad = C.MultitypeFuncGraph("clip_grad")
class ClipGradients(nn.Cell):
"""
Clip gradients.
Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self, grads, clip_type, clip_value):
"""
Construct gradient clip network.
Args:
grads (list): List of gradient tuples.
            clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
            clip_value (float): Specifies how much to clip.
Returns:
List, a list of clipped_grad tuples.
"""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
new_grads = new_grads + (t,)
return new_grads
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in [0, 1]:
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
""" grad scale """
return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of GNMT network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network: Cell. The training network. Note that loss function should have
been added.
optimizer: Optimizer. Optimizer for updating the weights.
Returns:
Tuple[Tensor, Tensor, Tensor], loss, overflow, sen.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(GNMTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.all_reduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode not in ParallelMode.MODE_LIST:
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
self.add_flags(has_effect=True)
self.loss_scalar = P.ScalarSummary()
def construct(self, inputs, labels, sens=None):
"""
network processing
overflow testing
"""
weights = self.weights
loss = self.network(inputs, labels)
# Alloc status.
init = self.alloc_status()
# Clear overflow buffer.
self.clear_before_grad(init)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(inputs, labels, self.cast(scaling_sens, mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# Apply grad reducer on grads.
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
# Sum overflow flag over devices.
flag_reduce = self.all_reduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
self.loss_scalar("loss", loss)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)

View File: src/unet_parts.py

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unet Components"""
import mindspore.nn as nn
import mindspore.ops.operations as F
from mindspore.ops import Maximum
from mindspore.ops import DepthToSpace as dts
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.initializer import XavierUniform
import mindspore as ms
ms.set_seed(1212)
class LRelu(nn.Cell):
""" activation function """
def __init__(self):
super(LRelu, self).__init__()
self.max = Maximum()
def construct(self, x):
""" construct of lrelu activation """
return self.max(x * 0.2, x)
class DoubleConv(nn.Cell):
"""conv2d for two times with lrelu activation"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super(DoubleConv, self).__init__()
if not mid_channels:
mid_channels = out_channels
self.kernel_init = XavierUniform()
self.double_conv = nn.SequentialCell(
[nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, pad_mode="same",
weight_init=self.kernel_init), LRelu(),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, pad_mode="same",
weight_init=self.kernel_init), LRelu()])
def construct(self, x):
""" construct of double conv2d """
return self.double_conv(x)
class Down(nn.Cell):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super(Down, self).__init__()
self.maxpool_conv = nn.SequentialCell(
[nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same"),
DoubleConv(in_channels, out_channels)]
)
def construct(self, x):
""" construct of down cell """
return self.maxpool_conv(x)
class Up(nn.Cell):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels):
super(Up, self).__init__()
self.concat = F.Concat(axis=1)
self.kernel_init = TruncatedNormal(0.02)
self.conv = DoubleConv(in_channels, out_channels)
self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2,
pad_mode='same', weight_init=self.kernel_init)
def construct(self, x1, x2):
""" construct of up cell """
x1 = self.up(x1)
x = self.concat((x1, x2))
return self.conv(x)
class OutConv(nn.Cell):
"""trans data into RGB channels"""
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.kernel_init = XavierUniform()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, pad_mode='same', weight_init=self.kernel_init)
self.DtS = dts(block_size=2)
def construct(self, x):
""" construct of last conv """
x = self.conv(x)
x = self.DtS(x)
return x

View File: test_sony.py

@ -0,0 +1,136 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test"""
from __future__ import division
import argparse as arg
import os
import glob
from PIL import Image
import h5py
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor, dtype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet_parts import DoubleConv, Down, Up, OutConv
class UNet(nn.Cell):
""" Unet """
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.inc = DoubleConv(n_channels, 32)
self.down1 = Down(32, 64)
self.down2 = Down(64, 128)
self.down3 = Down(128, 256)
self.down4 = Down(256, 512)
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
self.up4 = Up(64, 32)
self.outc = OutConv(32, n_classes)
def construct(self, x):
"""Unet construct"""
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
def pack_raw(raw):
""" pack sony raw data into 4 channels """
im = np.maximum(raw - 512, 0) / (16383 - 512) # subtract the black level
im = np.expand_dims(im, axis=2)
img_shape = im.shape
H = img_shape[0]
W = img_shape[1]
out = np.concatenate((im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
return out
def get_test_data(input_dir1, gt_dir1, test_ids1):
""" trans input raw data into arrays then pack into a list """
final_test_inputs = []
for test_id in test_ids1:
in_files = glob.glob(input_dir1 + '%05d_00*.hdf5' % test_id)
gt_files = glob.glob(gt_dir1 + '%05d_00*.hdf5' % test_id)
gt_path = gt_files[0]
gt_fn = os.path.basename(gt_path)
        gt_exposure = float(gt_fn[9: -6])  # exposure seconds parsed from names like '00001_00_10s.hdf5'
for in_path in in_files:
in_fn = os.path.basename(in_path)
in_exposure = float(in_fn[9: -6])
ratio = min(gt_exposure / in_exposure, 300.0)
ima = h5py.File(in_path, 'r')
in_rawed = ima.get('in')[:]
input_image = np.expand_dims(pack_raw(in_rawed), axis=0) * ratio
input_image = np.minimum(input_image, 1.0)
input_image = input_image.transpose([0, 3, 1, 2])
input_image = np.float32(input_image)
final_test_inputs.append(input_image)
return final_test_inputs
if __name__ == '__main__':
parser = arg.ArgumentParser(description='Mindspore SID Eval')
parser.add_argument('--device_target', default='Ascend',
help='device where the code will be implemented')
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
local_data_path = args.data_url
input_dir = os.path.join(local_data_path, 'short/')
gt_dir = os.path.join(local_data_path, 'long/')
test_fns = glob.glob(gt_dir + '1*.hdf5')
test_ids = [int(os.path.basename(test_fn)[0:5]) for test_fn in test_fns]
ckpt_dir = args.checkpoint_path
param_dict = load_checkpoint(ckpt_dir)
net = UNet(4, 12)
load_param_into_net(net, param_dict)
in_ims = get_test_data(input_dir, gt_dir, test_ids)
i = 0
for in_im in in_ims:
output = net(Tensor(in_im, dtype.float32))
output = output.asnumpy()
output = np.minimum(np.maximum(output, 0), 1)
output = np.trunc(output[0] * 255)
        output = output.astype(np.uint8)  # uint8, not int8: pixel values run 0-255 and PIL 'RGB' expects unsigned bytes
output = output.transpose([1, 2, 0])
image_out = Image.fromarray(output, 'RGB')
image_out.save('output_%d.png' % i)
i += 1

View File: train_sony.py

@ -0,0 +1,223 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
from __future__ import division
import os
import glob
import argparse as arg
import ast
import h5py
import numpy as np
from mindspore import context, Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import L1Loss
from mindspore.nn.dynamic_lr import piecewise_constant_lr as pc_lr
from mindspore.nn.dynamic_lr import warmup_lr
import mindspore.dataset as ds
from mindspore.communication.management import init
import mindspore.nn as nn
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet_parts import DoubleConv, Down, Up, OutConv
from src.myutils import GNMTTrainOneStepWithLossScaleCell, WithLossCell
from src.configs import config
class UNet(nn.Cell):
""" Unet """
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.inc = DoubleConv(n_channels, 32)
self.down1 = Down(32, 64)
self.down2 = Down(64, 128)
self.down3 = Down(128, 256)
self.down4 = Down(256, 512)
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
self.up4 = Up(64, 32)
self.outc = OutConv(32, n_classes)
def construct(self, x):
""" Unet construct """
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
def pack_raw(raw):
""" pack sony raw data into 4 channels """
im = np.maximum(raw - 512, 0) / (16383 - 512) # subtract the black level
im = np.expand_dims(im, axis=2)
img_shape = im.shape
H = img_shape[0]
W = img_shape[1]
out = np.concatenate((im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
return out
def get_dataset(input_dir1, gt_dir1, train_ids1, num_shards=None, shard_id=None, distribute=False):
""" get mindspore dataset from raw data """
input_final_data = []
gt_final_data = []
for train_id in train_ids1:
in_files = glob.glob(input_dir1 + '%05d_00*.hdf5' % train_id)
gt_files = glob.glob(gt_dir1 + '%05d_00*.hdf5' % train_id)
gt_path = gt_files[0]
gt_fn = os.path.basename(gt_path)
gt_exposure = float(gt_fn[9: -6])
gt = h5py.File(gt_path, 'r')
gt_rawed = gt.get('gt')[:]
gt_image = np.expand_dims(np.float32(gt_rawed / 65535.0), axis=0)
gt_image = gt_image.transpose([0, 3, 1, 2])
for in_path in in_files:
gt_final_data.append(gt_image[0])
in_fn = os.path.basename(in_path)
in_exposure = float(in_fn[9: -6])
ratio = min(gt_exposure / in_exposure, 300)
im = h5py.File(in_path, 'r')
in_rawed = im.get('in')[:]
input_image = np.expand_dims(pack_raw(in_rawed), axis=0) * ratio
input_image = np.float32(input_image)
input_image = input_image.transpose([0, 3, 1, 2])
input_final_data.append(input_image[0])
data = (input_final_data, gt_final_data)
if distribute:
datasets = ds.NumpySlicesDataset(data, ['input', 'label'], shuffle=False,
num_shards=num_shards, shard_id=shard_id)
else:
datasets = ds.NumpySlicesDataset(data, ['input', 'label'], shuffle=False)
return datasets
def dynamic_lr(steps_per_epoch, warmup_epochss): # if warmup, plus warmup_epochs
""" learning rate with warmup"""
milestone = [(1200 + warmup_epochss) * steps_per_epoch,
(1300 + warmup_epochss) * steps_per_epoch,
(1700 + warmup_epochss) * steps_per_epoch,
(2500 + warmup_epochss) * steps_per_epoch]
learning_rates = [3e-4, 1e-5, 3e-6, 1e-6]
lrs = pc_lr(milestone, learning_rates)
return lrs
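# For reference: with steps_per_epoch=35 (the per-epoch step count reported in the README)
# and warmup_epochss=500, the base lr of 3e-4 holds until epoch 1700, then drops to 1e-5
# (until epoch 1800), 3e-6 (until epoch 2200), and finally 1e-6 (until epoch 3000).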
def RandomCropAndFlip(image, label):
""" random crop and flip """
ps = 512
# random crop
h = image.shape[1]
w = image.shape[2]
xx = np.random.randint(0, h - ps)
yy = np.random.randint(0, w - ps)
    image = image[:, xx:xx + ps, yy:yy + ps]
    label = label[:, xx * 2:xx * 2 + ps * 2, yy * 2:yy * 2 + ps * 2]  # the label is at 2x resolution (DepthToSpace output), so crop a 2x window
# random flip
if np.random.randint(2) == 1: # random flip
image = np.flip(image, axis=1)
label = np.flip(label, axis=1)
if np.random.randint(2) == 1:
image = np.flip(image, axis=2)
label = np.flip(label, axis=2)
if np.random.randint(2) == 1: # random transpose
image = np.transpose(image, (0, 2, 1))
label = np.transpose(label, (0, 2, 1))
image = np.minimum(image, 1.0)
return image, label
if __name__ == "__main__":
parser = arg.ArgumentParser(description='Mindspore SID Example')
parser.add_argument('--device_target', default='Ascend',
help='device where the code will be implemented')
parser.add_argument('--data_url', required=True, default=None, help='Location of data')
parser.add_argument('--pre_trained', required=False, default=None, help='Ckpt file path')
parser.add_argument('--run_distribute', type=ast.literal_eval, required=False, default=None,
help='If run distributed')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.run_distribute:
device_num = int(os.getenv('RANK_SIZE'))
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
local_data_path = args.data_url
input_dir = os.path.join(local_data_path, 'short/')
gt_dir = os.path.join(local_data_path, 'long/')
train_fns = glob.glob(gt_dir + '0*.hdf5')
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]
net = UNet(4, 12)
net_loss = L1Loss()
net = WithLossCell(net, net_loss)
if args.run_distribute:
dataset = get_dataset(input_dir, gt_dir, train_ids,
num_shards=device_num, shard_id=device_id, distribute=True)
else:
dataset = get_dataset(input_dir, gt_dir, train_ids)
transform_list = [RandomCropAndFlip]
dataset = dataset.map(transform_list, input_columns=['input', 'label'], output_columns=['input', 'label'])
dataset = dataset.shuffle(buffer_size=161)
dataset = dataset.batch(batch_size=config.batch_size, drop_remainder=True)
batches_per_epoch = dataset.get_dataset_size()
lr_warm = warmup_lr(learning_rate=3e-4, total_step=config.warmup_epochs * batches_per_epoch,
step_per_epoch=batches_per_epoch, warmup_epoch=config.warmup_epochs)
lr = dynamic_lr(batches_per_epoch, config.warmup_epochs)
    lr = lr_warm + lr[config.warmup_epochs * batches_per_epoch:]  # drop the warm-up span already covered by lr_warm (the schedule is indexed per step, not per epoch)
net_opt = nn.Adam(net.trainable_params(), lr)
scale_manager = DynamicLossScaleManager()
net = GNMTTrainOneStepWithLossScaleCell(net, net_opt, scale_manager.get_update_cell())
ckpt_dir = args.pre_trained
if ckpt_dir is not None:
param_dict = load_checkpoint(ckpt_dir)
load_param_into_net(net, param_dict)
model = Model(net)
loss_cb = LossMonitor()
time_cb = TimeMonitor(data_size=4)
config_ck = CheckpointConfig(save_checkpoint_steps=100 * batches_per_epoch, keep_checkpoint_max=100)
ckpoint_cb = ModelCheckpoint(prefix='sony_trained_net', directory=config.train_output_dir, config=config_ck)
callbacks_list = [ckpoint_cb, loss_cb, time_cb]
model.train(epoch=config.total_epochs, train_dataset=dataset,
callbacks=callbacks_list,
dataset_sink_mode=True)