!12807 MindSpore community network model collection campaign: SRCNN
From: @fuyu-wang

@@ -0,0 +1,175 @@ README.md

# Contents

- [SRCNN Description](#srcnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)

# [SRCNN Description](#contents)

SRCNN learns an end-to-end mapping between low- and high-resolution images, with little extra pre/post-processing beyond the optimization. Thanks to its lightweight structure, SRCNN achieves performance superior to state-of-the-art methods.

[Paper](https://arxiv.org/pdf/1501.00092.pdf): Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang. Image Super-Resolution Using Deep Convolutional Networks. 2014.

# [Model Architecture](#contents)

The overall network architecture of SRCNN is shown below:

[Link](https://arxiv.org/pdf/1501.00092.pdf)

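As a rough guide to how the pieces in this commit fit together, below is a minimal, untested sketch of single-image inference: the low-resolution image is first upscaled with bicubic interpolation, and the network then refines its luminance (Y) channel. It uses src/srcnn.py and src/utils.py shown later in this commit; the checkpoint path srcnn.ckpt and the input file input.png are placeholders.

```python
import numpy as np
import PIL.Image as pil_image

import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.srcnn import SRCNN
from src.utils import convert_rgb_to_y

scale = 2
net = SRCNN()
load_param_into_net(net, load_checkpoint('srcnn.ckpt'))    # placeholder checkpoint path
net.set_train(False)

img = pil_image.open('input.png').convert('RGB')           # placeholder input image
img = img.resize((img.width * scale, img.height * scale), resample=pil_image.BICUBIC)
y = convert_rgb_to_y(np.array(img).astype(np.float32)) / 255.

x = Tensor(y[np.newaxis, np.newaxis, :, :], ms.float32)    # NCHW, single Y channel
sr_y = net(x).asnumpy().squeeze() * 255.                   # refined luminance channel
```

A full RGB result would combine this refined Y channel with the bicubically upscaled Cb/Cr channels, e.g. via convert_ycbcr_to_rgb in src/utils.py.
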
# [Dataset](#contents)

- Training Dataset
    - ILSVRC2013_DET_train: 395918 images, 200 classes
- Evaluation Dataset
    - Set5: 5 images
    - Set14: 14 images
    - Set5 & Set14 download url: http://vllab.ucmerced.edu/wlai24/LapSRN/results/SR_testing_datasets.zip
    - BSDS200: 200 images
    - BSDS200 download url: http://vllab.ucmerced.edu/wlai24/LapSRN/results/SR_training_datasets.zip
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py.

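Both the training and evaluation pipelines operate on the luminance (Y) channel only. As a small illustration of the conversion helper defined in src/utils.py later in this commit (run from the srcnn directory; the random array is just a stand-in for a real image):

```python
import numpy as np

from src.utils import convert_rgb_to_y

rgb = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)  # toy H x W x 3 image
y = convert_rgb_to_y(rgb)  # BT.601-style luma, roughly in [16, 235], shape (4, 4)
print(y.shape, y.min(), y.max())
```
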
# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
.
└─srcnn
  ├─README.md
  ├─scripts
    ├─run_distribute_train_gpu.sh  # launch distributed training on the GPU platform
    └─run_eval_gpu.sh              # launch evaluation on the GPU platform
  ├─src
    ├─config.py                    # parameter configuration
    ├─dataset.py                   # data preprocessing
    ├─metric.py                    # accuracy metric (PSNR)
    ├─utils.py                     # commonly used helper functions
    └─srcnn.py                     # network definition
  ├─create_dataset.py              # generate the mindrecord training dataset
  ├─eval.py                        # evaluate the network
  └─train.py                       # train the network
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

```python
'lr': 1e-4,                          # learning rate
'patch_size': 33,                    # patch size
'stride': 99,                        # stride used when cutting patches
'scale': 2,                          # image scale
'epoch_size': 20,                    # total number of epochs
'batch_size': 16,                    # input batch size
'save_checkpoint': True,             # whether to save ckpt files
'keep_checkpoint_max': 10,           # max number of checkpoints to keep
'save_checkpoint_path': 'outputs/'   # checkpoint save path
```

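These values are exposed through an EasyDict in src/config.py (shown later in this commit), so any script run from the srcnn directory can read them directly, for example:

```python
from src.config import srcnn_cfg as config

print(config.patch_size, config.stride, config.scale)  # 33 99 2
```
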
## [Training Process](#contents)

### Dataset

To create the dataset, first download the training dataset and then convert it to mindrecord files as follows.

```shell
python create_dataset.py --src_folder=/dataset/ILSVRC2013_DET_train --output_folder=/dataset/mindrecord_dir
```

### Usage

```bash
GPU:
sh run_distribute_train_gpu.sh DEVICE_NUM VISIBLE_DEVICES(0,1,2,3,4,5,6,7) DATASET_PATH
```

### Launch

```bash
# distributed training example (8p) for GPU
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /dataset/train

# standalone training example for GPU
sh run_distribute_train_gpu.sh 1 0 /dataset/train
```

You can find the checkpoint files together with the training results in the log.

## [Evaluation Process](#contents)

### Usage

```bash
# Evaluation
sh run_eval_gpu.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH
```

### Launch

```bash
# Evaluation with a checkpoint
sh run_eval_gpu.sh 1 /dataset/val /ckpt_dir/srcnn-20_*.ckpt
```

### Result

Evaluation results are stored in the scripts path, where you can find output like the following in the log.

result {'PSNR': 36.72421418219669}

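The PSNR above is computed by src/metric.py (shown later in this commit) on images scaled to [0, 1], i.e. PSNR = 10 * log10(1 / MSE). A tiny self-contained check with made-up data:

```python
import numpy as np

y = np.random.rand(1, 33, 33).astype(np.float32)   # toy ground-truth patch in [0, 1]
y_pred = np.clip(y + 0.01, 0., 1.)                  # prediction off by at most 0.01
psnr = 10. * np.log10(1. / np.mean((y_pred - y) ** 2))
print(psnr)                                         # roughly 40 dB
```
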
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | SRCNN                     |
| -------------------------- | ------------------------- |
| Resource                   | NV PCIE V100-32G          |
| Uploaded Date              | 03/02/2021                |
| MindSpore Version          | master                    |
| Dataset                    | ImageNet2013, scale: 2    |
| Training Parameters        | src/config.py             |
| Optimizer                  | Adam                      |
| Loss Function              | MSELoss                   |
| Loss                       | 0.00179                   |
| Total time                 | 1 h (8p)                  |
| Checkpoint for Fine tuning | 671 K (.ckpt file)        |

### Inference Performance

| Parameters        | SRCNN                        |
| ----------------- | ---------------------------- |
| Resource          | NV PCIE V100-32G             |
| Uploaded Date     | 03/02/2021                   |
| MindSpore Version | master                       |
| Dataset           | Set5/Set14/BSDS200, scale: 2 |
| batch_size        | 1                            |
| PSNR              | 36.72 / 32.58 / 33.81        |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,82 @@ create_dataset.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create Dataset."""
import os
import argparse
import glob
import numpy as np
import PIL.Image as pil_image
from PIL import ImageFile

from mindspore.mindrecord import FileWriter

from src.config import srcnn_cfg as config
from src.utils import convert_rgb_to_y

ImageFile.LOAD_TRUNCATED_IMAGES = True

parser = argparse.ArgumentParser(description='Generate dataset file.')
parser.add_argument("--src_folder", type=str, required=True, help="Raw data folder.")
parser.add_argument("--output_folder", type=str, required=True, help="Dataset output path.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)
    prefix = "srcnn.mindrecord"
    file_num = 32
    patch_size = config.patch_size
    stride = config.stride
    scale = config.scale
    mindrecord_path = os.path.join(args.output_folder, prefix)
    writer = FileWriter(mindrecord_path, file_num)

    srcnn_json = {
        "lr": {"type": "float32", "shape": [1, patch_size, patch_size]},
        "hr": {"type": "float32", "shape": [1, patch_size, patch_size]},
    }
    writer.add_schema(srcnn_json, "srcnn_json")
    image_list = []
    file_list = sorted(os.listdir(args.src_folder))
    for file_name in file_list:
        path = os.path.join(args.src_folder, file_name)
        if os.path.isfile(path):
            image_list.append(path)
        else:
            for image_path in sorted(glob.glob('{}/*'.format(path))):
                image_list.append(image_path)

    print("image_list size ", len(image_list), flush=True)

    for path in image_list:
        hr = pil_image.open(path).convert('RGB')
        hr_width = (hr.width // scale) * scale
        hr_height = (hr.height // scale) * scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        for i in range(0, lr.shape[0] - patch_size + 1, stride):
            for j in range(0, lr.shape[1] - patch_size + 1, stride):
                lr_res = np.expand_dims(lr[i:i + patch_size, j:j + patch_size] / 255., 0)
                hr_res = np.expand_dims(hr[i:i + patch_size, j:j + patch_size] / 255., 0)
                row = {"lr": lr_res, "hr": hr_res}
                writer.write_raw_data([row])

    writer.commit()
    print("Finish!")

@@ -0,0 +1,55 @@ eval.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""srcnn evaluation"""
import argparse
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import srcnn_cfg as config
from src.dataset import create_eval_dataset
from src.srcnn import SRCNN
from src.metric import SRCNNpsnr

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="srcnn eval")
    parser.add_argument('--dataset_path', type=str, required=True, help="Dataset path.")
    parser.add_argument('--checkpoint_path', type=str, required=True, help="Checkpoint file path.")
    parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU",),
                        help="Device target, only GPU is supported.")
    args, _ = parser.parse_known_args()

    if args.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            save_graphs=False)
    else:
        raise ValueError("Unsupported device target.")

    eval_ds = create_eval_dataset(args.dataset_path)

    net = SRCNN()
    lr = Tensor(config.lr, ms.float32)
    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
    loss = nn.MSELoss(reduction='mean')
    param_dict = load_checkpoint(args.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'PSNR': SRCNNpsnr()})

    res = model.eval(eval_ds, dataset_sink_mode=False)
    print("result ", res)

@@ -0,0 +1 @@ requirements.txt
Pillow

@@ -0,0 +1,66 @@ scripts/run_distribute_train_gpu.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -lt 3 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED](optional)"
    exit 1
fi

if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
    echo "error: DEVICE_NUM=$1 is not in (1-8)"
    exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train_parallel" ];
then
    rm -rf train_parallel
fi
mkdir train_parallel
cd train_parallel || exit

export CUDA_VISIBLE_DEVICES="$2"

if [ -f "$4" ]  # optional pre-trained checkpoint
then
    if [ $1 -gt 1 ]
    then
        mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
            --dataset_path=$3 \
            --run_distribute=True \
            --pre_trained=$4 > log.txt 2>&1 &
    else
        python3 ${BASEPATH}/../train.py \
            --dataset_path=$3 \
            --pre_trained=$4 > log.txt 2>&1 &
    fi
else
    if [ $1 -gt 1 ]
    then
        mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
            --run_distribute=True \
            --dataset_path=$3 > log.txt 2>&1 &
    else
        python3 ${BASEPATH}/../train.py \
            --dataset_path=$3 > log.txt 2>&1 &
    fi
fi

@@ -0,0 +1,43 @@ scripts/run_eval_gpu.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -lt 3 ]
then
    echo "Usage: sh run_eval_gpu.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# check checkpoint file
if [ ! -f $3 ]
then
    echo "error: CHECKPOINT_PATH=$3 is not a file"
    exit 1
fi

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH

if [ -d "./eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval

export CUDA_VISIBLE_DEVICES="$1"

python3 ${BASEPATH}/../eval.py \
    --dataset_path=$2 \
    --checkpoint_path=$3 > eval/eval.log 2>&1 &

@@ -0,0 +1,29 @@ src/config.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""

from easydict import EasyDict as edict

srcnn_cfg = edict({
    'lr': 1e-4,
    'patch_size': 33,
    'stride': 99,
    'scale': 2,
    'epoch_size': 20,
    'batch_size': 16,
    'save_checkpoint': True,
    'keep_checkpoint_max': 10,
    'save_checkpoint_path': 'outputs/'
})

@@ -0,0 +1,62 @@ src/dataset.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import glob
import numpy as np
import PIL.Image as pil_image

import mindspore.dataset as ds

from src.config import srcnn_cfg as config
from src.utils import convert_rgb_to_y

class EvalDataset:
    def __init__(self, images_dir):
        self.images_dir = images_dir
        scale = config.scale
        self.lr_group = []
        self.hr_group = []
        for image_path in sorted(glob.glob('{}/*'.format(images_dir))):
            hr = pil_image.open(image_path).convert('RGB')
            hr_width = (hr.width // scale) * scale
            hr_height = (hr.height // scale) * scale
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
            lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
            hr = np.array(hr).astype(np.float32)
            lr = np.array(lr).astype(np.float32)
            hr = convert_rgb_to_y(hr)
            lr = convert_rgb_to_y(lr)

            self.lr_group.append(lr)
            self.hr_group.append(hr)

    def __len__(self):
        return len(self.lr_group)

    def __getitem__(self, idx):
        return np.expand_dims(self.lr_group[idx] / 255., 0), np.expand_dims(self.hr_group[idx] / 255., 0)

def create_train_dataset(mindrecord_file, batch_size=1, shard_id=0, num_shard=1, num_parallel_workers=4):
    data_set = ds.MindDataset(mindrecord_file, columns_list=["lr", "hr"], num_shards=num_shard,
                              shard_id=shard_id, num_parallel_workers=num_parallel_workers, shuffle=True)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set

def create_eval_dataset(images_dir, batch_size=1):
    dataset = EvalDataset(images_dir)
    data_set = ds.GeneratorDataset(dataset, ["lr", "hr"], shuffle=False)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set

@@ -0,0 +1,46 @@ src/metric.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for accuracy evaluation."""
from mindspore import nn
import numpy as np

class SRCNNpsnr(nn.Metric):
    def __init__(self):
        super(SRCNNpsnr, self).__init__()
        self.clear()

    def clear(self):
        self.val = 0
        self.sum = 0
        self.count = 0

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('SRCNNpsnr needs 2 inputs (y_pred, y), but got {}'.format(len(inputs)))

        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])

        n = y_pred.shape[0]  # number of samples in this batch
        val = 10. * np.log10(1. / np.mean((y_pred - y) ** 2))

        self.val = val
        self.sum += val * n
        self.count += n

    def eval(self):
        if self.count == 0:
            raise RuntimeError('PSNR can not be calculated, because the number of samples is 0.')
        return self.sum / self.count

@@ -0,0 +1,30 @@ src/srcnn.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore.nn as nn

class SRCNN(nn.Cell):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2, pad_mode='pad')
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2, pad_mode='pad')
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2, pad_mode='pad')
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

@@ -0,0 +1,37 @@ src/utils.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

def convert_rgb_to_y(img):
    if isinstance(img, np.ndarray):
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    raise Exception('Unknown Type', type(img))

def convert_rgb_to_ycbcr(img):
    if isinstance(img, np.ndarray):
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    raise Exception('Unknown Type', type(img))

def convert_ycbcr_to_rgb(img):
    if isinstance(img, np.ndarray):
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    raise Exception('Unknown Type', type(img))

@@ -0,0 +1,105 @@ train.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""srcnn training"""

import os
import argparse
import ast

import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import ParallelMode

from src.config import srcnn_cfg as config
from src.dataset import create_train_dataset
from src.srcnn import SRCNN

set_seed(1)

def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """remove useless parameters according to filter_list"""
    for key in list(origin_dict.keys()):
        for name in param_filter:
            if name in key:
                print("Delete parameter from checkpoint: ", key)
                del origin_dict[key]
                break

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="srcnn training")
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--device_num', type=int, default=1, help='Device num.')
    parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU",),
                        help="Device target, only GPU is supported.")
    parser.add_argument('--pre_trained', type=str, default='', help='model_path, local pretrained model to load')
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default: false.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter head weight parameters, default is False.")
    args, _ = parser.parse_known_args()

    if args.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            save_graphs=False)
    else:
        raise ValueError("Unsupported device target.")

    rank = 0
    device_num = 1
    if args.run_distribute:
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)

    train_dataset = create_train_dataset(args.dataset_path, batch_size=config.batch_size,
                                         shard_id=rank, num_shard=device_num)

    step_size = train_dataset.get_dataset_size()

    # define net
    net = SRCNN()

    # init weight
    if args.pre_trained:
        param_dict = load_checkpoint(args.pre_trained)
        if args.filter_weight:
            filter_list = [x.name for x in net.end_point.get_parameters()]
            filter_checkpoint_parameter_by_list(param_dict, filter_list)
        load_param_into_net(net, param_dict)

    lr = Tensor(config.lr, ms.float32)

    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
    loss = nn.MSELoss(reduction='mean')
    model = Model(net, loss_fn=loss, optimizer=opt)

    # define callbacks
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    if config.save_checkpoint and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        ckpt_cb = ModelCheckpoint(prefix="srcnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)

    model.train(config.epoch_size, train_dataset, callbacks=callbacks)