!13062 Add IPT model

From: @xiaoan95
Reviewed-by: @c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-03-10 15:09:34 +08:00 committed by Gitee
commit bc6ad21278
13 changed files with 2327 additions and 0 deletions

View File

@ -0,0 +1,68 @@
"""eval script"""
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src import ipt
from src.args import args
from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize
from mindspore import context
import mindspore.dataset as de
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
def main():
"""eval"""
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False)
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
train_loader = train_de_dataset.create_dict_iterator()
net_m = ipt.IPT(args)
print('load mindspore net successfully.')
if args.pth_path:
param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_m, param_dict)
net_m.set_train(False)
num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1))
for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR']
hr = imgs['HR']
hr_np = np.float32(hr.asnumpy())
pred = net_m.infrc(lr)
pred_np = np.float32(pred.asnumpy())
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
psnrs[batch_idx, 0] = psnr
if args.denoise:
print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
elif args.derain:
print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0)))
else:
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
if __name__ == '__main__':
print("Start main function!")
main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

View File

@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vitm import ViT
def IPT(*args, **kwargs):
return ViT(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == 'IPT':
return IPT(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,147 @@
<TOC>
# Pre-Trained Image Processing Transformer (IPT)
This repository is an official implementation of the paper "Pre-Trained Image Processing Transformer" from CVPR 2021.
We study the low-level computer vision task (e.g., denoising, super-resolution and deraining) and develop a new pre-trained model, namely, image processing transformer (IPT). To maximally excavate the capability of transformer, we present to utilize the well-known ImageNet benchmark for generating a large amount of corrupted image pairs. The IPT model is trained on these images with multi-heads and multi-tails. In addition, the contrastive learning is introduced for well adapting to different image processing tasks. The pre-trained model can therefore efficiently employed on desired task after fine-tuning. With only one pre-trained model, IPT outperforms the current state-of-the-art methods on various low-level benchmarks.
If you find our work useful in your research or publication, please cite our work:
[1] Hanting Chen, Yunhe Wang, Tianyu Guo, Chang Xu, Yiping Deng, Zhenhua Liu, Siwei Ma, Chunjing Xu, Chao Xu, and Wen Gao. **"Pre-trained image processing transformer"**. <i>**CVPR 2021**.</i> [[arXiv](https://arxiv.org/abs/2012.00364)]
@inproceedings{chen2020pre,
title={Pre-trained image processing transformer},
author={Chen, Hanting and Wang, Yunhe and Guo, Tianyu and Xu, Chang and Deng, Yiping and Liu, Zhenhua and Ma, Siwei and Xu, Chunjing and Xu, Chao and Gao, Wen},
booktitle={CVPR},
year={2021}
}
## Model architecture
### The overall network architecture of IPT is shown as below:
![architecture](./ipt.png)
## Dataset
The benchmark datasets can be downloaded as follows:
For super-resolution:
Set5,
[Set14](https://sites.google.com/site/romanzeyde/research-interests),
[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
[Urban100](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
For denoising:
[CBSD68](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/).
For deraining:
[Rain100L](https://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)
The result images are converted into YCbCr color space. The PSNR is evaluated on the Y channel only.
## Requirements
### Hardware (GPU)
> Prepare hardware environment with GPU.
### Framework
> [MindSpore](https://www.mindspore.cn/install/en)
### For more information, please check the resources below
[MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
[MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## Script Description
> This is the inference script of IPT, you can following steps to finish the test of image processing tasks, like SR, denoise and derain, via the corresponding pretrained models.
### Scripts and Sample Code
```
IPT
├── eval.py # inference entry
├── image
│   └── ipt.png # the illustration of IPT network
├── model
│   ├── IPT_denoise30.ckpt # denoise model weights for noise level 30
│   ├── IPT_denoise50.ckpt # denoise model weights for noise level 50
│   ├── IPT_derain.ckpt # derain model weights
│   ├── IPT_sr2.ckpt # X2 super-resolution model weights
│   ├── IPT_sr3.ckpt # X3 super-resolution model weights
│   └── IPT_sr4.ckpt # X4 super-resolution model weights
├── readme.md # Readme
├── scripts
│   └── run_eval.sh # inference script for all tasks
└── src
├── args.py # options/hyper-parameters of IPT
├── data
│   ├── common.py # common dataset
│   ├── __init__.py # Class data init function
│   └── srdata.py # flow of loading sr data
├── foldunfold_stride.py # function of fold and unfold operations for images
├── metrics.py # PSNR calculator
├── template.py # setting of model selection
└── vitm.py # IPT network
```
### Script Parameter
> For details about hyperparameters, see src/args.py.
## Evaluation
### Evaluation Process
> Inference example:
> For SR x4:
```
python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt
```
> Or one can run following script for all tasks.
```
sh scripts/run_eval.sh
```
### Evaluation Result
The result are evaluated by the value of PSNR (Peak Signal-to-Noise Ratio), and the format is as following.
```
result: {"Mean psnr of Se5 x4 is 32.68"}
```
## Performance
### Inference Performance
The Results on all tasks are listed as below.
Super-resolution results:
| Scale | Set5 | Set14 | B100 | Urban100 |
| ----- | ----- | ----- | ----- | ----- |
| ×2 | 38.36 | 34.54 | 32.50 | 33.88 |
| ×3 | 34.83 | 30.96 | 29.39 | 29.59 |
| ×4 | 32.68 | 29.01 | 27.81 | 27.24 |
Denoising results:
| noisy level | CBSD68 | Urban100 |
| ----- | ----- | ----- |
| 30 | 32.37 | 33.82 |
| 50 | 29.94 | 31.56 |
Derain results:
| Task | Rain100L |
| ----- | ----- |
| Derain | 41.98 |
## ModeZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
DATA_SET=$3
PATH_CHECKPOINT=$4
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
##denoise
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
##derain
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &

View File

@ -0,0 +1,239 @@
'''args'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
from src import template
parser = argparse.ArgumentParser(description='EDSR and MDSR')
parser.add_argument('--debug', action='store_true',
help='Enables debug mode')
parser.add_argument('--template', default='.',
help='You can set various templates in option.py')
# Hardware specifications
parser.add_argument('--n_threads', type=int, default=6,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
help='random seed')
# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='vtip',
help='model name')
parser.add_argument('--act', type=str, default='relu',
help='activation function')
parser.add_argument('--pre_train', type=str, default='',
help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=16,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
help='residual scaling')
parser.add_argument('--shift_mean', default=True,
help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
choices=('single', 'half'),
help='FP precision for test (single | half)')
# Option for Residual dense network (RDN)
parser.add_argument('--G0', type=int, default=64,
help='default number of filters. (Use in RDN)')
parser.add_argument('--RDNkSize', type=int, default=3,
help='default kernel size. (Use in RDN)')
parser.add_argument('--RDNconfig', type=str, default='B',
help='parameters config of RDN. (Use in RDN)')
# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
help='number of feature maps reduction')
# Training specifications
parser.add_argument('--reset', action='store_true',
help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=300,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=1,
help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
help='k value for adversarial loss')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--decay', type=str, default='200',
help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
choices=('SGD', 'ADAM', 'RMSprop'),
help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default='1e8',
help='skipping batch that has large error')
# Log specifications
parser.add_argument('--save', type=str, default='/cache/results/edsr_baseline_x2/',
help='file name to save')
parser.add_argument('--load', type=str, default='',
help='file name to load')
parser.add_argument('--resume', type=int, default=0,
help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
help='save output results')
parser.add_argument('--save_gt', action='store_true',
help='save low-resolution and high-resolution images together')
parser.add_argument('--scalelr', type=int, default=0)
# cloud
parser.add_argument('--moxfile', type=int, default=1)
parser.add_argument('--imagenet', type=int, default=0)
parser.add_argument('--data_url', type=str, help='path to dataset')
parser.add_argument('--train_url', type=str, help='train_dir')
parser.add_argument('--pretrain', type=str, default='')
parser.add_argument('--pth_path', type=str, default='')
parser.add_argument('--load_query', type=int, default=0)
# transformer
parser.add_argument('--patch_dim', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=12)
parser.add_argument('--num_layers', type=int, default=12)
parser.add_argument('--dropout_rate', type=float, default=0)
parser.add_argument('--no_norm', action='store_true')
parser.add_argument('--post_norm', action='store_true')
parser.add_argument('--no_mlp', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--chop_new', action='store_true')
parser.add_argument('--pos_every', action='store_true')
parser.add_argument('--no_pos', action='store_true')
parser.add_argument('--num_queries', type=int, default=6)
parser.add_argument('--reweight', action='store_true')
# denoise
parser.add_argument('--denoise', action='store_true')
parser.add_argument('--sigma', type=float, default=25)
# derain
parser.add_argument('--derain', action='store_true')
parser.add_argument('--finetune', action='store_true')
parser.add_argument('--derain_test', type=int, default=10)
# alltask
parser.add_argument('--alltask', action='store_true')
# dehaze
parser.add_argument('--dehaze', action='store_true')
parser.add_argument('--dehaze_test', type=int, default=100)
parser.add_argument('--indoor', action='store_true')
parser.add_argument('--outdoor', action='store_true')
parser.add_argument('--nochange', action='store_true')
# deblur
parser.add_argument('--deblur', action='store_true')
parser.add_argument('--deblur_test', type=int, default=1000)
# distribute
parser.add_argument('--init_method', type=str,
default=None, help='master address')
parser.add_argument('--rank', type=int, default=0,
help='Index of current task')
parser.add_argument('--world_size', type=int, default=1,
help='Total number of tasks')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use.')
parser.add_argument('--dist-url', default='', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--distribute', action='store_true')
args, unparsed = parser.parse_known_args()
template.set_template(args)
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
args.epochs = 1e8
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False

View File

@ -0,0 +1,35 @@
"""data"""
# Copyright 2021 Huawei Technologies Co., Ltd"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from importlib import import_module
class Data:
"""data"""
def __init__(self, args):
self.loader_train = None
self.loader_test = []
for d in args.data_test:
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']:
m = import_module('data.benchmark')
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
else:
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
m = import_module('data.' + module_name.lower())
testset = getattr(m, module_name)(args, train=False, name=d)
self.loader_test.append(
testset
)

View File

@ -0,0 +1,93 @@
"""common"""
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import random
import numpy as np
import skimage.color as sc
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
"""common"""
ih, iw = args[0].shape[:2]
tp = patch_size
ip = tp // scale
ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)
if not input_large:
tx, ty = scale * ix, scale * iy
else:
tx, ty = ix, iy
ret = [
args[0][iy:iy + ip, ix:ix + ip, :],
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
]
return ret
def set_channel(*args, n_channels=3):
"""common"""
def _set_channel(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
c = img.shape[2]
if n_channels == 1 and c == 3:
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
elif n_channels == 3 and c == 1:
img = np.concatenate([img] * n_channels, 2)
return img[:, :, :n_channels]
return [_set_channel(a) for a in args]
def np2Tensor(*args, rgb_range=255):
"""common"""
def _np2Tensor(img):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
tensor = np_transpose.astype(np.float32)
tensor = tensor * (rgb_range / 255)
# tensor = torch.from_numpy(np_transpose).float()
# tensor.mul_(rgb_range / 255)
return tensor
return [_np2Tensor(a) for a in args]
def augment(*args, hflip=True, rot=True):
"""common"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(a) for a in args]

View File

@ -0,0 +1,301 @@
"""srdata"""
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import glob
import random
import pickle
from src.data import common
import numpy as np
import imageio
def search(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
else:
item_list = []
return item_list
def search_dehaze(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
extend_list = search_dehaze(path, target)
if extend_list is not None:
item_list.extend(extend_list)
elif path.split('/')[-2].endswith(target):
item_list.append(path)
return item_list
class SRData():
"""srdata"""
def __init__(self, args, name='', train=True, benchmark=False):
self.args = args
self.name = name
self.train = train
self.split = 'train' if train else 'test'
self.do_eval = True
self.benchmark = benchmark
self.input_large = (args.model == 'VDSR')
self.scale = args.scale
self.idx_scale = 0
if self.args.derain:
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.derain_lr_test = search(self.derain_test, "rain")
self.derain_hr_test = [path.replace(
"rainy/", "no") for path in self.derain_lr_test]
self._set_filesystem(args.dir_data)
if args.ext.find('img') < 0:
path_bin = os.path.join(self.apath, 'bin')
os.makedirs(path_bin, exist_ok=True)
list_hr, list_lr = self._scan()
if args.ext.find('img') >= 0 or benchmark:
self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0:
os.makedirs(
self.dir_hr.replace(self.apath, path_bin),
exist_ok=True
)
for s in self.scale:
if s == 1:
os.makedirs(
os.path.join(self.dir_hr),
exist_ok=True
)
else:
os.makedirs(
os.path.join(
self.dir_lr.replace(self.apath, path_bin),
'X{}'.format(s)
),
exist_ok=True
)
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
for h in list_hr:
b = h.replace(self.apath, path_bin)
b = b.replace(self.ext[0], '.pt')
self.images_hr.append(b)
self._check_and_load(args.ext, h, b, verbose=True)
for i, ll in enumerate(list_lr):
for l in ll:
b = l.replace(self.apath, path_bin)
b = b.replace(self.ext[1], '.pt')
self.images_lr[i].append(b)
self._check_and_load(args.ext, l, b, verbose=True)
# Below functions as used to prepare images
def _scan(self):
"""srdata"""
names_hr = sorted(
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
)
names_lr = [[] for _ in self.scale]
for f in names_hr:
filename, _ = os.path.splitext(os.path.basename(f))
for si, s in enumerate(self.scale):
if s != 1:
scale = s
names_lr[si].append(os.path.join(
self.dir_lr, 'X{}/{}x{}{}'.format(
s, filename, scale, self.ext[1]
)
))
for si, s in enumerate(self.scale):
if s == 1:
names_lr[si] = names_hr
return names_hr, names_lr
def _set_filesystem(self, dir_data):
self.apath = os.path.join(dir_data, self.name[0])
self.dir_hr = os.path.join(self.apath, 'HR')
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
self.ext = ('.png', '.png')
def _check_and_load(self, ext, img, f, verbose=True):
if not os.path.isfile(f) or ext.find('reset') >= 0:
if verbose:
print('Making a binary: {}'.format(f))
with open(f, 'wb') as _f:
pickle.dump(imageio.imread(img), _f)
def __getitem__(self, idx):
if self.args.model == 'vtip' and self.args.derain and self.scale[
self.idx_scale] == 1 and not self.args.finetune:
norain, rain, _ = self._load_rain_test(idx)
pair = common.set_channel(
*[rain, norain], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1:
hr, _ = self._load_file_hr(idx)
pair = self.get_patch_hr(hr)
pair = common.set_channel(*[pair], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
noise = np.random.randn(*pair_t[0].shape) * self.args.sigma
lr = pair_t[0] + noise
lr = np.float32(np.clip(lr, 0, 255))
return lr, pair_t[0]
lr, hr, _ = self._load_file(idx)
pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
def __len__(self):
if self.train:
return len(self.images_hr) * self.repeat
if self.args.derain and not self.args.alltask:
return int(len(self.derain_hr_test) / self.args.derain_test)
return len(self.images_hr)
def _get_index(self, idx):
"""srdata"""
if self.train:
return idx % len(self.images_hr)
return idx
def _load_file_hr(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_hr = self.images_hr[idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
return hr, filename
def _load_rain(self, idx, rain_img=False):
"""srdata"""
idx = random.randint(0, len(self.derain_img_list) - 1)
f_lr = self.derain_img_list[idx]
if rain_img:
norain = imageio.imread(f_lr.replace("rainstreak", "norain"))
rain = imageio.imread(f_lr.replace("rainstreak", "rain"))
return norain, rain
lr = imageio.imread(f_lr)
return lr
def _load_rain_test(self, idx):
"""srdata"""
f_hr = self.derain_hr_test[idx]
f_lr = self.derain_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_lr))
norain = imageio.imread(f_hr)
rain = imageio.imread(f_lr)
return norain, rain, filename
def _load_denoise(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_lr = self.images_hr[idx]
norain = imageio.imread(f_lr)
rain = imageio.imread(f_lr.replace("HR", "LR_bicubic"))
return norain, rain
def _load_file(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_hr = self.images_hr[idx]
f_lr = self.images_lr[self.idx_scale][idx]
filename, _ = os.path.splitext(os.path.basename(f_hr))
if self.args.ext == 'img' or self.benchmark:
hr = imageio.imread(f_hr)
lr = imageio.imread(f_lr)
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
with open(f_lr, 'rb') as _f:
lr = pickle.load(_f)
return lr, hr, filename
def get_patch_hr(self, hr):
"""srdata"""
if self.train:
hr = self.get_patch_img_hr(
hr,
patch_size=self.args.patch_size,
scale=1
)
return hr
def get_patch_img_hr(self, img, patch_size=96, scale=2):
"""srdata"""
ih, iw = img.shape[:2]
tp = patch_size
ip = tp // scale
ix = random.randrange(0, iw - ip + 1)
iy = random.randrange(0, ih - ip + 1)
ret = img[iy:iy + ip, ix:ix + ip, :]
return ret
def get_patch(self, lr, hr):
"""srdata"""
scale = self.scale[self.idx_scale]
if self.train:
lr, hr = common.get_patch(
lr, hr,
patch_size=self.args.patch_size * scale,
scale=scale,
multi=(len(self.scale) > 1)
)
if not self.args.no_augment:
lr, hr = common.augment(lr, hr)
else:
ih, iw = lr.shape[:2]
hr = hr[0:ih * scale, 0:iw * scale]
return lr, hr
def set_scale(self, idx_scale):
"""srdata"""
if not self.input_large:
self.idx_scale = idx_scale
else:
self.idx_scale = random.randint(0, len(self.scale) - 1)

View File

@ -0,0 +1,241 @@
'''stride'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class _stride_unfold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
stride=-1):
super(_stride_unfold_, self).__init__()
if stride == -1:
self.stride = kernel_size
else:
self.stride = stride
self.kernel_size = kernel_size
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.unfold = _unfold_(kernel_size)
def construct(self, x):
"""stride"""
N, C, H, W = x.shape
leftup_idx_x = []
leftup_idx_y = []
nh = int(H / self.stride)
nw = int(W / self.stride)
for i in range(nh):
leftup_idx_x.append(i * self.stride)
for i in range(nw):
leftup_idx_y.append(i * self.stride)
NumBlock_x = len(leftup_idx_x)
NumBlock_y = len(leftup_idx_y)
zeroslike = P.ZerosLike()
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
NumBlock_y * self.kernel_size), mstype.float32)
N, C, H, W = unf_x.shape
for i in range(NumBlock_x):
for j in range(NumBlock_y):
unf_i = i * self.kernel_size
unf_j = j * self.kernel_size
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size,
org_j:org_j + self.kernel_size]
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
(zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
y = self.unfold(unf_x)
return y
class _stride_fold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):
super(_stride_fold_, self).__init__()
if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
self.kernel_size = [kernel_size, kernel_size]
if stride == -1:
self.stride = kernel_size[0]
else:
self.stride = stride
self.output_shape = output_shape
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.fold = _fold_(kernel_size)
def construct(self, x):
'''stride'''
if self.output_shape[0] == -1:
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx = []
for i in range(0, H, self.kernel_size[0]):
leftup_idx.append(i)
NumBlock = len(leftup_idx)
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
for i in range(NumBlock):
for j in range(NumBlock):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx[i]
org_j = leftup_idx[j]
fills = x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
y = fold_x
else:
NumBlock_x = int(
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
NumBlock_y = int(
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
large_shape = [NumBlock_x * self.kernel_size[0],
NumBlock_y * self.kernel_size[1]]
self.fold = _fold_(self.kernel_size, large_shape)
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx_x = []
leftup_idx_y = []
for i in range(NumBlock_x):
leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1])
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
for i in range(NumBlock_x):
for j in range(NumBlock_y):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
y = fold_x
return y
class _unfold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
stride=-1):
super(_unfold_, self).__init__()
if stride == -1:
self.stride = kernel_size
self.kernel_size = kernel_size
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
'''stride'''
N, C, H, W = x.shape
numH = int(H / self.kernel_size)
numW = int(W / self.kernel_size)
if numH * self.kernel_size != H or numW * self.kernel_size != W:
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, int(
numH * numW), self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
return output_img
class _fold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):
super(_fold_, self).__init__()
if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
self.kernel_size = [kernel_size, kernel_size]
if stride == -1:
self.stride = kernel_size[0]
self.output_shape = output_shape
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
'''stride'''
N, C, L = x.shape
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
if self.output_shape[0] == -1:
numH = int(np.sqrt(C))
numW = int(np.sqrt(C))
org_H = int(numH * self.kernel_size[0])
org_W = org_H
else:
org_H = int(self.output_shape[0])
org_W = int(self.output_shape[1])
numH = int(org_H / self.kernel_size[0])
numW = int(org_W / self.kernel_size[1])
output_img = self.reshape(
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
output_img = self.reshape(
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
return output_img

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,56 @@
'''metrics'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
def quantize(img, rgb_range):
'''metrics'''
pixel_range = 255 / rgb_range
img = np.multiply(img, pixel_range)
img = np.clip(img, 0, 255)
img = np.round(img) / pixel_range
return img
def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
'''metrics'''
hr = np.float32(hr)
sr = np.float32(sr)
diff = (sr - hr) / rgb_range
gray_coeffs = np.array([65.738, 129.057, 25.064]
).reshape((1, 3, 1, 1)) / 256
diff = np.multiply(diff, gray_coeffs).sum(1)
if hr.size == 1:
return 0
if scale != 1:
shave = scale
else:
shave = scale + 6
if scale == 1:
valid = diff
else:
valid = diff[..., shave:-shave, shave:-shave]
mse = np.mean(pow(valid, 2))
return -10 * math.log10(mse)
def rgb2ycbcr(img, y_only=True):
'''metrics'''
img.astype(np.float32)
if y_only:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
return rlt

View File

@ -0,0 +1,67 @@
'''temp'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
def set_template(args):
'''temp'''
if args.template.find('jpeg') >= 0:
args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg'
args.epochs = 200
args.decay = '100'
if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR'
args.n_resblocks = 32
args.n_feats = 256
args.res_scale = 0.1
if args.template.find('MDSR') >= 0:
args.model = 'MDSR'
args.patch_size = 48
args.epochs = 650
if args.template.find('DDBPN') >= 0:
args.model = 'DDBPN'
args.patch_size = 128
args.scale = '4'
args.data_test = 'Set5'
args.batch_size = 20
args.epochs = 1000
args.decay = '500'
args.gamma = 0.1
args.weight_decay = 1e-4
args.loss = '1*MSE'
if args.template.find('GAN') >= 0:
args.epochs = 200
args.lr = 5e-5
args.decay = '150'
if args.template.find('RCAN') >= 0:
args.model = 'RCAN'
args.n_resgroups = 10
args.n_resblocks = 20
args.n_feats = 64
args.chop = True
if args.template.find('VDSR') >= 0:
args.model = 'VDSR'
args.n_resblocks = 20
args.n_feats = 64
args.patch_size = 41
args.lr = 1e-1