!13062 Add IPT model
From: @xiaoan95 Reviewed-by: @c_34 Signed-off-by: @c_34
This commit is contained in commit bc6ad21278.

@@ -0,0 +1,68 @@

"""eval script"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src import ipt
from src.args import args
from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize

from mindspore import context
import mindspore.dataset as de
from mindspore.train.serialization import load_checkpoint, load_param_into_net

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)


def main():
    """eval"""
    for arg in vars(args):
        if vars(args)[arg] == 'True':
            vars(args)[arg] = True
        elif vars(args)[arg] == 'False':
            vars(args)[arg] = False
    train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
    train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False)
    train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
    train_loader = train_de_dataset.create_dict_iterator()

    net_m = ipt.IPT(args)
    print('load mindspore net successfully.')
    if args.pth_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
    net_m.set_train(False)
    num_imgs = train_de_dataset.get_dataset_size()
    psnrs = np.zeros((num_imgs, 1))
    for batch_idx, imgs in enumerate(train_loader):
        lr = imgs['LR']
        hr = imgs['HR']
        hr_np = np.float32(hr.asnumpy())
        pred = net_m.infrc(lr)
        pred_np = np.float32(pred.asnumpy())
        pred_np = quantize(pred_np, 255)
        psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
        psnrs[batch_idx, 0] = psnr
    if args.denoise:
        print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
    elif args.derain:
        print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0)))
    else:
        print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))


if __name__ == '__main__':
    print("Start main function!")
    main()

Binary image file (1.7 MiB) not shown.

@@ -0,0 +1,26 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vitm import ViT


def IPT(*args, **kwargs):
    return ViT(*args, **kwargs)


def create_network(name, *args, **kwargs):
    if name == 'IPT':
        return IPT(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")

@@ -0,0 +1,147 @@

<TOC>

# Pre-Trained Image Processing Transformer (IPT)

This repository is an official implementation of the paper "Pre-Trained Image Processing Transformer" from CVPR 2021.

We study low-level computer vision tasks (e.g., denoising, super-resolution and deraining) and develop a new pre-trained model, namely, the image processing transformer (IPT). To maximally excavate the capability of the transformer, we utilize the well-known ImageNet benchmark to generate a large number of corrupted image pairs. The IPT model is trained on these images with multiple heads and multiple tails. In addition, contrastive learning is introduced to adapt well to different image processing tasks. The pre-trained model can therefore be efficiently employed on the desired task after fine-tuning. With only one pre-trained model, IPT outperforms the current state-of-the-art methods on various low-level benchmarks.

If you find our work useful in your research or publication, please cite our work:
[1] Hanting Chen, Yunhe Wang, Tianyu Guo, Chang Xu, Yiping Deng, Zhenhua Liu, Siwei Ma, Chunjing Xu, Chao Xu, and Wen Gao. **"Pre-trained image processing transformer"**. <i>**CVPR 2021**.</i> [[arXiv](https://arxiv.org/abs/2012.00364)]

    @inproceedings{chen2020pre,
      title={Pre-trained image processing transformer},
      author={Chen, Hanting and Wang, Yunhe and Guo, Tianyu and Xu, Chang and Deng, Yiping and Liu, Zhenhua and Ma, Siwei and Xu, Chunjing and Xu, Chao and Gao, Wen},
      booktitle={CVPR},
      year={2021}
    }

## Model architecture

### The overall network architecture of IPT is shown below:

![architecture](./ipt.png)

## Dataset

The benchmark datasets can be downloaded as follows:

For super-resolution:

Set5,

[Set14](https://sites.google.com/site/romanzeyde/research-interests),

[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),

[Urban100](https://sites.google.com/site/jbhuang0604/publications/struct_sr).

For denoising:

[CBSD68](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/).

For deraining:

[Rain100L](https://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)

The result images are converted into the YCbCr color space, and the PSNR is evaluated on the Y channel only.
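
For reference, the Y-channel PSNR used here can be computed with the helpers in `src/metrics.py`; the snippet below is a minimal sketch in which random arrays stand in for the network output and the ground truth:

```
import numpy as np
from src.metrics import calc_psnr, quantize

# Placeholder NCHW RGB images in [0, 255]; replace with a real prediction / ground truth.
pred = np.random.rand(1, 3, 128, 128) * 255
gt = np.random.rand(1, 3, 128, 128) * 255

pred = quantize(pred, 255)    # clip and round to valid pixel values
psnr = calc_psnr(pred, gt, scale=4, rgb_range=255.0, y_only=True)
print('PSNR: %.4f dB' % psnr)
```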

## Requirements

### Hardware (GPU)

> Prepare the hardware environment with a GPU.

### Framework

> [MindSpore](https://www.mindspore.cn/install/en)

### For more information, please check the resources below:

[MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)

[MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## Script Description

> This is the inference code of IPT. You can follow the steps below to test image processing tasks such as SR, denoising and deraining with the corresponding pre-trained models.

### Scripts and Sample Code

```
IPT
├── eval.py                      # inference entry
├── image
│   └── ipt.png                  # the illustration of IPT network
├── model
│   ├── IPT_denoise30.ckpt       # denoise model weights for noise level 30
│   ├── IPT_denoise50.ckpt       # denoise model weights for noise level 50
│   ├── IPT_derain.ckpt          # derain model weights
│   ├── IPT_sr2.ckpt             # X2 super-resolution model weights
│   ├── IPT_sr3.ckpt             # X3 super-resolution model weights
│   └── IPT_sr4.ckpt             # X4 super-resolution model weights
├── readme.md                    # Readme
├── scripts
│   └── run_eval.sh              # inference script for all tasks
└── src
    ├── args.py                  # options/hyper-parameters of IPT
    ├── data
    │   ├── common.py            # common dataset
    │   ├── __init__.py          # Class data init function
    │   └── srdata.py            # flow of loading sr data
    ├── foldunfold_stride.py     # function of fold and unfold operations for images
    ├── metrics.py               # PSNR calculator
    ├── template.py              # setting of model selection
    └── vitm.py                  # IPT network
```

### Script Parameters

> For details about hyperparameters, see src/args.py.
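
> The options are post-processed when `src/args.py` is imported; for example, `--scale` and `--data_test` accept '+'-separated values, as the sketch below illustrates:

```
from src.args import args

print(args.scale)      # e.g. "--scale 2+3+4" is parsed into [2, 3, 4]
print(args.data_test)  # e.g. "--data_test Set5+Set14" becomes ['Set5', 'Set14']
```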

## Evaluation

### Evaluation Process

> Inference example:
> For SR x4:

```
python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt
```

> Or one can run the following script for all tasks.

```
sh scripts/run_eval.sh
```
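
> `run_eval.sh` takes four positional arguments: the device ID, the data directory, the test dataset name and the checkpoint path, e.g. `sh scripts/run_eval.sh 0 ../../data/ Set14 ./model/IPT_sr4.ckpt` (the checkpoint path here is only an illustration).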

### Evaluation Result

The results are evaluated by PSNR (Peak Signal-to-Noise Ratio), and the output has the following format.

```
result: {"Mean psnr of Set5 x4 is 32.68"}
```

## Performance

### Inference Performance

The results on all tasks are listed below.

Super-resolution results:

| Scale | Set5 | Set14 | B100 | Urban100 |
| ----- | ----- | ----- | ----- | ----- |
| ×2 | 38.36 | 34.54 | 32.50 | 33.88 |
| ×3 | 34.83 | 30.96 | 29.39 | 29.59 |
| ×4 | 32.68 | 29.01 | 27.81 | 27.24 |

Denoising results:

| Noise level | CBSD68 | Urban100 |
| ----- | ----- | ----- |
| 30 | 32.37 | 33.82 |
| 50 | 29.94 | 31.56 |

Derain results:

| Task | Rain100L |
| ----- | ----- |
| Derain | 41.98 |

## ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,31 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=$1
DATA_DIR=$2
DATA_SET=$3
PATH_CHECKPOINT=$4

## super-resolution
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &

## denoise
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &

## derain
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &

@@ -0,0 +1,239 @@

'''args'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
from src import template

parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=6,
                    help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')

# Data specifications
parser.add_argument('--dir_data', type=str, default='/cache/data/',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Model specifications
parser.add_argument('--model', default='vtip',
                    help='model name')

parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
parser.add_argument('--pre_train', type=str, default='',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=16,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Option for Residual dense network (RDN)
parser.add_argument('--G0', type=int, default=64,
                    help='default number of filters. (Use in RDN)')
parser.add_argument('--RDNkSize', type=int, default=3,
                    help='default kernel size. (Use in RDN)')
parser.add_argument('--RDNconfig', type=str, default='B',
                    help='parameters config of RDN. (Use in RDN)')

# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
                    help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
                    help='number of feature maps reduction')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=300,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=1,
                    help='input batch size for testing')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
                    help='k value for adversarial loss')

# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--decay', type=str, default='200',
                    help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                    help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
                    help='gradient clipping threshold (0 = no clipping)')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default=1e8,
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='/cache/results/edsr_baseline_x2/',
                    help='file name to save')
parser.add_argument('--load', type=str, default='',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
                    help='save output results')
parser.add_argument('--save_gt', action='store_true',
                    help='save low-resolution and high-resolution images together')

parser.add_argument('--scalelr', type=int, default=0)
# cloud
parser.add_argument('--moxfile', type=int, default=1)
parser.add_argument('--imagenet', type=int, default=0)
parser.add_argument('--data_url', type=str, help='path to dataset')
parser.add_argument('--train_url', type=str, help='train_dir')
parser.add_argument('--pretrain', type=str, default='')
parser.add_argument('--pth_path', type=str, default='')
parser.add_argument('--load_query', type=int, default=0)
# transformer
parser.add_argument('--patch_dim', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=12)
parser.add_argument('--num_layers', type=int, default=12)
parser.add_argument('--dropout_rate', type=float, default=0)
parser.add_argument('--no_norm', action='store_true')
parser.add_argument('--post_norm', action='store_true')
parser.add_argument('--no_mlp', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--chop_new', action='store_true')
parser.add_argument('--pos_every', action='store_true')
parser.add_argument('--no_pos', action='store_true')
parser.add_argument('--num_queries', type=int, default=6)
parser.add_argument('--reweight', action='store_true')

# denoise
parser.add_argument('--denoise', action='store_true')
parser.add_argument('--sigma', type=float, default=25)

# derain
parser.add_argument('--derain', action='store_true')
parser.add_argument('--finetune', action='store_true')
parser.add_argument('--derain_test', type=int, default=10)
# alltask
parser.add_argument('--alltask', action='store_true')

# dehaze
parser.add_argument('--dehaze', action='store_true')
parser.add_argument('--dehaze_test', type=int, default=100)
parser.add_argument('--indoor', action='store_true')
parser.add_argument('--outdoor', action='store_true')
parser.add_argument('--nochange', action='store_true')
# deblur
parser.add_argument('--deblur', action='store_true')
parser.add_argument('--deblur_test', type=int, default=1000)

# distribute
parser.add_argument('--init_method', type=str,
                    default=None, help='master address')
parser.add_argument('--rank', type=int, default=0,
                    help='Index of current task')
parser.add_argument('--world_size', type=int, default=1,
                    help='Total number of tasks')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--dist-url', default='', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--distribute', action='store_true')

args, unparsed = parser.parse_known_args()
template.set_template(args)

args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')

if args.epochs == 0:
    args.epochs = 1e8

for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False

@@ -0,0 +1,35 @@

"""data"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from importlib import import_module


class Data:
    """data"""

    def __init__(self, args):
        self.loader_train = None
        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                testset
            )

@@ -0,0 +1,93 @@

"""common"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import random

import numpy as np
import skimage.color as sc


def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
    """common"""
    ih, iw = args[0].shape[:2]

    tp = patch_size
    ip = tp // scale

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret


def set_channel(*args, n_channels=3):
    """common"""

    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img[:, :, :n_channels]

    return [_set_channel(a) for a in args]


def np2Tensor(*args, rgb_range=255):
    """common"""

    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = np_transpose.astype(np.float32)
        tensor = tensor * (rgb_range / 255)
        # tensor = torch.from_numpy(np_transpose).float()
        # tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]


def augment(*args, hflip=True, rot=True):
    """common"""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(a) for a in args]
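
A small usage sketch of the helpers above (random arrays stand in for a real LR/HR image pair; shapes are only illustrative):

```
import numpy as np
from src.data.common import get_patch, set_channel, np2Tensor, augment

# Placeholder LR/HR pair, HWC uint8 in [0, 255].
lr = np.random.randint(0, 256, (24, 24, 3), dtype=np.uint8)
hr = np.random.randint(0, 256, (96, 96, 3), dtype=np.uint8)

# patch_size is the HR patch size; the LR crop is patch_size // scale.
lr_p, hr_p = get_patch(lr, hr, patch_size=48, scale=4)
lr_p, hr_p = augment(lr_p, hr_p)                    # random flip / 90-degree rotation
lr_p, hr_p = set_channel(lr_p, hr_p, n_channels=3)  # guarantee three channels
lr_t, hr_t = np2Tensor(lr_p, hr_p, rgb_range=255)   # HWC uint8 -> CHW float32
print(lr_t.shape, hr_t.shape)                       # (3, 12, 12) (3, 48, 48)
```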

@@ -0,0 +1,301 @@

"""srdata"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import glob
import random
import pickle

from src.data import common

import numpy as np
import imageio


def search(root, target="JPEG"):
    """srdata"""
    item_list = []
    items = os.listdir(root)
    for item in items:
        path = os.path.join(root, item)
        if os.path.isdir(path):
            item_list.extend(search(path, target))
        elif path.split('/')[-1].startswith(target):
            item_list.append(path)
        elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
            item_list.append(path)
        else:
            item_list = []
    return item_list


def search_dehaze(root, target="JPEG"):
    """srdata"""
    item_list = []
    items = os.listdir(root)
    for item in items:
        path = os.path.join(root, item)
        if os.path.isdir(path):
            extend_list = search_dehaze(path, target)
            if extend_list is not None:
                item_list.extend(extend_list)
        elif path.split('/')[-2].endswith(target):
            item_list.append(path)
    return item_list


class SRData():
    """srdata"""

    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0

        if self.args.derain:
            self.derain_test = os.path.join(args.dir_data, "Rain100L")
            self.derain_lr_test = search(self.derain_test, "rain")
            self.derain_hr_test = [path.replace(
                "rainy/", "no") for path in self.derain_lr_test]
        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                if s == 1:
                    os.makedirs(
                        os.path.join(self.dir_hr),
                        exist_ok=True
                    )
                else:
                    os.makedirs(
                        os.path.join(
                            self.dir_lr.replace(self.apath, path_bin),
                            'X{}'.format(s)
                        ),
                        exist_ok=True
                    )

            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

    # The functions below are used to prepare images
    def _scan(self):
        """srdata"""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                if s != 1:
                    scale = s
                    names_lr[si].append(os.path.join(
                        self.dir_lr, 'X{}/{}x{}{}'.format(
                            s, filename, scale, self.ext[1]
                        )
                    ))
        for si, s in enumerate(self.scale):
            if s == 1:
                names_lr[si] = names_hr
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name[0])
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        if self.args.model == 'vtip' and self.args.derain and self.scale[
                self.idx_scale] == 1 and not self.args.finetune:
            norain, rain, _ = self._load_rain_test(idx)
            pair = common.set_channel(
                *[rain, norain], n_channels=self.args.n_colors)
            pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
            return pair_t[0], pair_t[1]
        if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1:
            hr, _ = self._load_file_hr(idx)
            pair = self.get_patch_hr(hr)
            pair = common.set_channel(*[pair], n_channels=self.args.n_colors)
            pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
            noise = np.random.randn(*pair_t[0].shape) * self.args.sigma
            lr = pair_t[0] + noise
            lr = np.float32(np.clip(lr, 0, 255))
            return lr, pair_t[0]
        lr, hr, _ = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)

        return pair_t[0], pair_t[1]

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat

        if self.args.derain and not self.args.alltask:
            return int(len(self.derain_hr_test) / self.args.derain_test)
        return len(self.images_hr)

    def _get_index(self, idx):
        """srdata"""
        if self.train:
            return idx % len(self.images_hr)
        return idx

    def _load_file_hr(self, idx):
        """srdata"""
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)

        return hr, filename

    def _load_rain(self, idx, rain_img=False):
        """srdata"""
        idx = random.randint(0, len(self.derain_img_list) - 1)
        f_lr = self.derain_img_list[idx]
        if rain_img:
            norain = imageio.imread(f_lr.replace("rainstreak", "norain"))
            rain = imageio.imread(f_lr.replace("rainstreak", "rain"))
            return norain, rain
        lr = imageio.imread(f_lr)
        return lr

    def _load_rain_test(self, idx):
        """srdata"""
        f_hr = self.derain_hr_test[idx]
        f_lr = self.derain_lr_test[idx]
        filename, _ = os.path.splitext(os.path.basename(f_lr))
        norain = imageio.imread(f_hr)
        rain = imageio.imread(f_lr)
        return norain, rain, filename

    def _load_denoise(self, idx):
        """srdata"""
        idx = self._get_index(idx)
        f_lr = self.images_hr[idx]
        norain = imageio.imread(f_lr)
        rain = imageio.imread(f_lr.replace("HR", "LR_bicubic"))
        return norain, rain

    def _load_file(self, idx):
        """srdata"""
        idx = self._get_index(idx)

        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)

        return lr, hr, filename

    def get_patch_hr(self, hr):
        """srdata"""
        if self.train:
            hr = self.get_patch_img_hr(
                hr,
                patch_size=self.args.patch_size,
                scale=1
            )

        return hr

    def get_patch_img_hr(self, img, patch_size=96, scale=2):
        """srdata"""
        ih, iw = img.shape[:2]

        tp = patch_size
        ip = tp // scale

        ix = random.randrange(0, iw - ip + 1)
        iy = random.randrange(0, ih - ip + 1)

        ret = img[iy:iy + ip, ix:ix + ip, :]

        return ret

    def get_patch(self, lr, hr):
        """srdata"""
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size * scale,
                scale=scale,
                multi=(len(self.scale) > 1)
            )
            if not self.args.no_augment:
                lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        """srdata"""
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)

@@ -0,0 +1,241 @@

'''stride'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


class _stride_unfold_(nn.Cell):
    '''stride'''

    def __init__(self,
                 kernel_size,
                 stride=-1):

        super(_stride_unfold_, self).__init__()
        if stride == -1:
            self.stride = kernel_size
        else:
            self.stride = stride
        self.kernel_size = kernel_size

        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.unfold = _unfold_(kernel_size)

    def construct(self, x):
        """stride"""
        N, C, H, W = x.shape
        leftup_idx_x = []
        leftup_idx_y = []
        nh = int(H / self.stride)
        nw = int(W / self.stride)
        for i in range(nh):
            leftup_idx_x.append(i * self.stride)
        for i in range(nw):
            leftup_idx_y.append(i * self.stride)
        NumBlock_x = len(leftup_idx_x)
        NumBlock_y = len(leftup_idx_y)
        zeroslike = P.ZerosLike()
        cc_2 = P.Concat(axis=2)
        cc_3 = P.Concat(axis=3)
        unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
                           NumBlock_y * self.kernel_size), mstype.float32)
        N, C, H, W = unf_x.shape
        for i in range(NumBlock_x):
            for j in range(NumBlock_y):
                unf_i = i * self.kernel_size
                unf_j = j * self.kernel_size
                org_i = leftup_idx_x[i]
                org_j = leftup_idx_y[j]
                fills = x[:, :, org_i:org_i + self.kernel_size,
                          org_j:org_j + self.kernel_size]
                unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
                    (zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
                        unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
                               zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
        y = self.unfold(unf_x)
        return y


class _stride_fold_(nn.Cell):
    '''stride'''

    def __init__(self,
                 kernel_size,
                 output_shape=(-1, -1),
                 stride=-1):

        super(_stride_fold_, self).__init__()
        if isinstance(kernel_size, (list, tuple)):
            self.kernel_size = kernel_size
        else:
            self.kernel_size = [kernel_size, kernel_size]

        if stride == -1:
            self.stride = self.kernel_size[0]
        else:
            self.stride = stride

        self.output_shape = output_shape

        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.fold = _fold_(kernel_size)

    def construct(self, x):
        '''stride'''
        # operators used to assemble the folded output
        zeroslike = P.ZerosLike()
        cc_2 = P.Concat(axis=2)
        cc_3 = P.Concat(axis=3)
        if self.output_shape[0] == -1:
            large_x = self.fold(x)
            N, C, H, _ = large_x.shape
            leftup_idx = []
            for i in range(0, H, self.kernel_size[0]):
                leftup_idx.append(i)
            NumBlock = len(leftup_idx)
            fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
                                (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)

            for i in range(NumBlock):
                for j in range(NumBlock):
                    fold_i = i * self.stride
                    fold_j = j * self.stride
                    org_i = leftup_idx[i]
                    org_j = leftup_idx[j]
                    fills = x[:, :, org_i:org_i + self.kernel_size[0],
                              org_j:org_j + self.kernel_size[1]]
                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
                        (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
                            fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
                                   zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
            y = fold_x
        else:
            NumBlock_x = int(
                (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
            NumBlock_y = int(
                (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
            large_shape = [NumBlock_x * self.kernel_size[0],
                           NumBlock_y * self.kernel_size[1]]
            self.fold = _fold_(self.kernel_size, large_shape)
            large_x = self.fold(x)
            N, C, H, _ = large_x.shape
            leftup_idx_x = []
            leftup_idx_y = []
            for i in range(NumBlock_x):
                leftup_idx_x.append(i * self.kernel_size[0])
            for i in range(NumBlock_y):
                leftup_idx_y.append(i * self.kernel_size[1])
            fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
                                (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
            for i in range(NumBlock_x):
                for j in range(NumBlock_y):
                    fold_i = i * self.stride
                    fold_j = j * self.stride
                    org_i = leftup_idx_x[i]
                    org_j = leftup_idx_y[j]
                    fills = x[:, :, org_i:org_i + self.kernel_size[0],
                              org_j:org_j + self.kernel_size[1]]
                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
                        (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
                            fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
                                   zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
            y = fold_x
        return y


class _unfold_(nn.Cell):
    '''stride'''

    def __init__(self,
                 kernel_size,
                 stride=-1):

        super(_unfold_, self).__init__()
        if stride == -1:
            self.stride = kernel_size
        self.kernel_size = kernel_size

        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        '''stride'''
        N, C, H, W = x.shape
        numH = int(H / self.kernel_size)
        numW = int(W / self.kernel_size)
        if numH * self.kernel_size != H or numW * self.kernel_size != W:
            x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
        output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))

        output_img = self.transpose(output_img, (0, 1, 2, 4, 3))

        output_img = self.reshape(output_img, (N, C, int(
            numH * numW), self.kernel_size, self.kernel_size))

        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))

        output_img = self.reshape(output_img, (N, int(numH * numW), -1))
        return output_img


class _fold_(nn.Cell):
    '''stride'''

    def __init__(self,
                 kernel_size,
                 output_shape=(-1, -1),
                 stride=-1):

        super(_fold_, self).__init__()

        if isinstance(kernel_size, (list, tuple)):
            self.kernel_size = kernel_size
        else:
            self.kernel_size = [kernel_size, kernel_size]

        if stride == -1:
            self.stride = self.kernel_size[0]
        self.output_shape = output_shape

        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        '''stride'''
        N, C, L = x.shape
        org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
        if self.output_shape[0] == -1:
            numH = int(np.sqrt(C))
            numW = int(np.sqrt(C))
            org_H = int(numH * self.kernel_size[0])
            org_W = org_H
        else:
            org_H = int(self.output_shape[0])
            org_W = int(self.output_shape[1])
            numH = int(org_H / self.kernel_size[0])
            numW = int(org_W / self.kernel_size[1])

        output_img = self.reshape(
            x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))

        output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
        output_img = self.reshape(
            output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))

        output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))

        output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
        return output_img
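
A small usage sketch of the patch helpers above (it assumes a working MindSpore installation and uses the default device; shapes are only illustrative):

```
import numpy as np
from mindspore import Tensor, context
from src.foldunfold_stride import _unfold_, _fold_

context.set_context(mode=context.PYNATIVE_MODE)

# Split a 1x3x32x32 image into four non-overlapping 16x16 patches, then fold them back.
x = Tensor(np.random.rand(1, 3, 32, 32).astype(np.float32))
patches = _unfold_(16)(x)       # shape (1, 4, 768): each row is one flattened 3x16x16 patch
restored = _fold_(16)(patches)  # back to shape (1, 3, 32, 32)
print(patches.shape, restored.shape)
```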

File diff suppressed because it is too large.

@@ -0,0 +1,56 @@

'''metrics'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np


def quantize(img, rgb_range):
    '''metrics'''
    pixel_range = 255 / rgb_range
    img = np.multiply(img, pixel_range)
    img = np.clip(img, 0, 255)
    img = np.round(img) / pixel_range
    return img


def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
    '''metrics'''
    hr = np.float32(hr)
    sr = np.float32(sr)
    diff = (sr - hr) / rgb_range
    gray_coeffs = np.array([65.738, 129.057, 25.064]
                           ).reshape((1, 3, 1, 1)) / 256
    diff = np.multiply(diff, gray_coeffs).sum(1)
    if hr.size == 1:
        return 0
    if scale != 1:
        shave = scale
    else:
        shave = scale + 6
    if scale == 1:
        valid = diff
    else:
        valid = diff[..., shave:-shave, shave:-shave]
    mse = np.mean(pow(valid, 2))
    return -10 * math.log10(mse)


def rgb2ycbcr(img, y_only=True):
    '''metrics'''
    img = img.astype(np.float32)
    if y_only:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
        return rlt

@@ -0,0 +1,67 @@

'''temp'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
def set_template(args):
    '''temp'''
    if args.template.find('jpeg') >= 0:
        args.data_train = 'DIV2K_jpeg'
        args.data_test = 'DIV2K_jpeg'
        args.epochs = 200
        args.decay = '100'

    if args.template.find('EDSR_paper') >= 0:
        args.model = 'EDSR'
        args.n_resblocks = 32
        args.n_feats = 256
        args.res_scale = 0.1

    if args.template.find('MDSR') >= 0:
        args.model = 'MDSR'
        args.patch_size = 48
        args.epochs = 650

    if args.template.find('DDBPN') >= 0:
        args.model = 'DDBPN'
        args.patch_size = 128
        args.scale = '4'

        args.data_test = 'Set5'

        args.batch_size = 20
        args.epochs = 1000
        args.decay = '500'
        args.gamma = 0.1
        args.weight_decay = 1e-4

        args.loss = '1*MSE'

    if args.template.find('GAN') >= 0:
        args.epochs = 200
        args.lr = 5e-5
        args.decay = '150'

    if args.template.find('RCAN') >= 0:
        args.model = 'RCAN'
        args.n_resgroups = 10
        args.n_resblocks = 20
        args.n_feats = 64
        args.chop = True

    if args.template.find('VDSR') >= 0:
        args.model = 'VDSR'
        args.n_resblocks = 20
        args.n_feats = 64
        args.patch_size = 41
        args.lr = 1e-1