diff --git a/model_zoo/research/cv/IPT/eval.py b/model_zoo/research/cv/IPT/eval.py
new file mode 100755
index 00000000000..dafbb05f6b8
--- /dev/null
+++ b/model_zoo/research/cv/IPT/eval.py
@@ -0,0 +1,68 @@
+"""eval script"""
+# Copyright 2021 Huawei Technologies Co., Ltd
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+from src import ipt
+from src.args import args
+from src.data.srdata import SRData
+from src.metrics import calc_psnr, quantize
+
+from mindspore import context
+import mindspore.dataset as de
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
+
+
+def main():
+    """eval"""
+    # normalize string booleans coming from the command line
+    for arg in vars(args):
+        if vars(args)[arg] == 'True':
+            vars(args)[arg] = True
+        elif vars(args)[arg] == 'False':
+            vars(args)[arg] = False
+    eval_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
+    eval_de_dataset = de.GeneratorDataset(eval_dataset, ['LR', "HR"], shuffle=False)
+    eval_de_dataset = eval_de_dataset.batch(1, drop_remainder=True)
+    eval_loader = eval_de_dataset.create_dict_iterator()
+
+    net_m = ipt.IPT(args)
+    print('load mindspore net successfully.')
+    if args.pth_path:
+        param_dict = load_checkpoint(args.pth_path)
+        load_param_into_net(net_m, param_dict)
+    net_m.set_train(False)
+    num_imgs = eval_de_dataset.get_dataset_size()
+    psnrs = np.zeros((num_imgs, 1))
+    for batch_idx, imgs in enumerate(eval_loader):
+        lr = imgs['LR']
+        hr = imgs['HR']
+        hr_np = np.float32(hr.asnumpy())
+        pred = net_m.infrc(lr)
+        pred_np = np.float32(pred.asnumpy())
+        pred_np = quantize(pred_np, 255)
+        psnr = calc_psnr(pred_np, hr_np, args.scale[0], 255.0, y_only=True)
+        psnrs[batch_idx, 0] = psnr
+    if args.denoise:
+        print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
+    elif args.derain:
+        print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0)[0]))
+    else:
+        print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
+
+
+if __name__ == '__main__':
+    print("Start main function!")
+    main()
diff --git a/model_zoo/research/cv/IPT/image/ipt.png b/model_zoo/research/cv/IPT/image/ipt.png
new file mode 100755
index 00000000000..cc79ab2bc89
Binary files /dev/null and b/model_zoo/research/cv/IPT/image/ipt.png differ
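Note: `eval.py` above clamps the network output with `quantize` and scores it with `calc_psnr(..., y_only=True)` from `src/metrics.py`, which is not part of this diff. For reviewer reference, a minimal NumPy sketch of that Y-only PSNR convention; the names mirror `src/metrics.py`, but the bodies below are illustrative, not the repo implementation:

```python
import numpy as np

def quantize(img, rgb_range):
    """Round to valid integer pixel values in [0, rgb_range]."""
    return np.clip(np.round(img * 255 / rgb_range), 0, 255) * rgb_range / 255

def calc_psnr(sr, hr, scale, rgb_range, y_only=True):
    """PSNR for (..., C, H, W) arrays in [0, rgb_range], shaving a `scale`-pixel border."""
    diff = (sr - hr) / rgb_range
    if y_only:
        # ITU-R BT.601 luma weights, matching the Y-channel protocol in the readme
        gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape(3, 1, 1) / 256
        diff = (diff * gray_coeffs).sum(axis=-3)
    valid = diff[..., scale:-scale, scale:-scale]
    return -10 * np.log10(np.mean(valid ** 2))
```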
diff --git a/model_zoo/research/cv/IPT/mindspore_hub_conf.py b/model_zoo/research/cv/IPT/mindspore_hub_conf.py
new file mode 100755
index 00000000000..92b7d30d549
--- /dev/null
+++ b/model_zoo/research/cv/IPT/mindspore_hub_conf.py
@@ -0,0 +1,26 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""hub config."""
+from src.vitm import ViT
+
+
+def IPT(*args, **kwargs):
+    return ViT(*args, **kwargs)
+
+
+def create_network(name, *args, **kwargs):
+    if name == 'IPT':
+        return IPT(*args, **kwargs)
+    raise NotImplementedError(f"{name} is not implemented in the repo")
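For reference, `create_network` is the entry point MindSpore Hub looks for in a `mindspore_hub_conf.py`. A minimal sketch of a direct call (illustrative; `vit_kwargs` is a placeholder for whatever constructor arguments `src.vitm.ViT` actually takes, which are not shown in this diff):

```python
from mindspore_hub_conf import create_network

vit_kwargs = {}  # placeholder: fill with the ViT constructor arguments
net = create_network('IPT', **vit_kwargs)  # any other name raises NotImplementedError
```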
diff --git a/model_zoo/research/cv/IPT/readme.md b/model_zoo/research/cv/IPT/readme.md
new file mode 100755
index 00000000000..a8cf16993bf
--- /dev/null
+++ b/model_zoo/research/cv/IPT/readme.md
@@ -0,0 +1,147 @@
+# Pre-Trained Image Processing Transformer (IPT)
+
+This repository is an official implementation of the paper "Pre-Trained Image Processing Transformer" from CVPR 2021.
+
+We study low-level computer vision tasks (e.g., denoising, super-resolution and deraining) and develop a new pre-trained model, namely, the image processing transformer (IPT). To maximally excavate the capability of the transformer, we propose to utilize the well-known ImageNet benchmark for generating a large amount of corrupted image pairs. The IPT model is trained on these images with multi-heads and multi-tails. In addition, contrastive learning is introduced to adapt well to different image processing tasks. The pre-trained model can therefore be efficiently employed on a desired task after fine-tuning. With only one pre-trained model, IPT outperforms the current state-of-the-art methods on various low-level benchmarks.
+
+If you find our work useful in your research or publication, please cite our work:
+
+[1] Hanting Chen, Yunhe Wang, Tianyu Guo, Chang Xu, Yiping Deng, Zhenhua Liu, Siwei Ma, Chunjing Xu, Chao Xu, and Wen Gao. **"Pre-trained image processing transformer"**. **CVPR 2021**. [[arXiv](https://arxiv.org/abs/2012.00364)]
+
+    @inproceedings{chen2020pre,
+        title={Pre-trained image processing transformer},
+        author={Chen, Hanting and Wang, Yunhe and Guo, Tianyu and Xu, Chang and Deng, Yiping and Liu, Zhenhua and Ma, Siwei and Xu, Chunjing and Xu, Chao and Gao, Wen},
+        booktitle={CVPR},
+        year={2021}
+    }
+
+## Model architecture
+
+### The overall network architecture of IPT is shown below:
+
+![architecture](./image/ipt.png)
+
+## Dataset
+
+The benchmark datasets can be downloaded as follows:
+
+For super-resolution:
+
+Set5,
+
+[Set14](https://sites.google.com/site/romanzeyde/research-interests),
+
+[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
+
+[Urban100](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
+
+For denoising:
+
+[CBSD68](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/).
+
+For deraining:
+
+[Rain100L](https://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html)
+
+The result images are converted into YCbCr color space, and PSNR is evaluated on the Y channel only.
+
+## Requirements
+
+### Hardware (GPU)
+
+> Prepare a hardware environment with GPU.
+
+### Framework
+
+> [MindSpore](https://www.mindspore.cn/install/en)
+
+### For more information, please check the resources below:
+
+[MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+
+[MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
+
+## Script Description
+
+> This is the inference code of IPT. You can follow the steps below to test image processing tasks such as SR, denoising and deraining with the corresponding pretrained models.
+
+### Scripts and Sample Code
+
+```
+IPT
+├── eval.py # inference entry
+├── image
+│   └── ipt.png # the illustration of IPT network
+├── model
+│   ├── IPT_denoise30.ckpt # denoise model weights for noise level 30
+│   ├── IPT_denoise50.ckpt # denoise model weights for noise level 50
+│   ├── IPT_derain.ckpt # derain model weights
+│   ├── IPT_sr2.ckpt # X2 super-resolution model weights
+│   ├── IPT_sr3.ckpt # X3 super-resolution model weights
+│   └── IPT_sr4.ckpt # X4 super-resolution model weights
+├── readme.md # Readme
+├── scripts
+│   └── run_eval.sh # inference script for all tasks
+└── src
+    ├── args.py # options/hyper-parameters of IPT
+    ├── data
+    │   ├── common.py # common data utilities (patching, channels, augmentation)
+    │   ├── __init__.py # Class data init function
+    │   └── srdata.py # flow of loading sr data
+    ├── foldunfold_stride.py # fold and unfold operations for images
+    ├── metrics.py # PSNR calculator
+    ├── template.py # setting of model selection
+    └── vitm.py # IPT network
+```
+
+### Script Parameters
+
+> For details about hyperparameters, see src/args.py.
+
+## Evaluation
+
+### Evaluation Process
+
+> Inference example:
+> For SR x4:
+
+```
+python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt
+```
+
+> Or one can run the following script for all tasks; the positional arguments are the device id, the data directory, the test dataset name and the checkpoint path.
+
+```
+sh scripts/run_eval.sh [DEVICE_ID] [DATA_DIR] [DATA_SET] [PATH_CHECKPOINT]
+```
+
+### Evaluation Result
+
+The results are evaluated by PSNR (Peak Signal-to-Noise Ratio), and the output format is as follows:
+
+```
+result: {"Mean psnr of Set5 x4 is 32.68"}
+```
+
+## Performance
+
+### Inference Performance
+
+The results on all tasks are listed below.
+
+Super-resolution results:
+
+| Scale | Set5 | Set14 | B100 | Urban100 |
+| ----- | ----- | ----- | ----- | ----- |
+| ×2 | 38.36 | 34.54 | 32.50 | 33.88 |
+| ×3 | 34.83 | 30.96 | 29.39 | 29.59 |
+| ×4 | 32.68 | 29.01 | 27.81 | 27.24 |
+
+Denoising results:
+
+| Noise level | CBSD68 | Urban100 |
+| ----- | ----- | ----- |
+| 30 | 32.37 | 33.82 |
+| 50 | 29.94 | 31.56 |
+
+Derain results:
+
+| Task | Rain100L |
+| ----- | ----- |
+| Derain | 41.98 |
+
+## ModelZoo Homepage
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
\ No newline at end of file
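A hypothetical invocation of `scripts/run_eval.sh` (shown next), using device 0, the data directory from the readme example, Set5 and the x4 checkpoint; adjust the paths to your layout:

```
sh scripts/run_eval.sh 0 ../../data/ Set5 ./model/IPT_sr4.ckpt
```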
diff --git a/model_zoo/research/cv/IPT/scripts/run_eval.sh b/model_zoo/research/cv/IPT/scripts/run_eval.sh
new file mode 100755
index 00000000000..18ccea71ddf
--- /dev/null
+++ b/model_zoo/research/cv/IPT/scripts/run_eval.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+export DEVICE_ID=$1
+DATA_DIR=$2
+DATA_SET=$3
+PATH_CHECKPOINT=$4
+
+## super-resolution (each run gets its own log so the parallel jobs do not overwrite each other)
+python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval_sr_x4.log 2>&1 &
+python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval_sr_x3.log 2>&1 &
+python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval_sr_x2.log 2>&1 &
+
+## denoise
+python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval_denoise30.log 2>&1 &
+python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval_denoise50.log 2>&1 &
+
+## derain
+python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval_derain.log 2>&1 &
\ No newline at end of file
diff --git a/model_zoo/research/cv/IPT/src/args.py b/model_zoo/research/cv/IPT/src/args.py
new file mode 100755
index 00000000000..c83aa560a55
--- /dev/null
+++ b/model_zoo/research/cv/IPT/src/args.py
@@ -0,0 +1,239 @@
+'''args'''
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import argparse +from src import template + +parser = argparse.ArgumentParser(description='EDSR and MDSR') + +parser.add_argument('--debug', action='store_true', + help='Enables debug mode') +parser.add_argument('--template', default='.', + help='You can set various templates in option.py') + +# Hardware specifications +parser.add_argument('--n_threads', type=int, default=6, + help='number of threads for data loading') +parser.add_argument('--cpu', action='store_true', + help='use cpu only') +parser.add_argument('--n_GPUs', type=int, default=1, + help='number of GPUs') +parser.add_argument('--seed', type=int, default=1, + help='random seed') + +# Data specifications +parser.add_argument('--dir_data', type=str, default='/cache/data/', + help='dataset directory') +parser.add_argument('--dir_demo', type=str, default='../test', + help='demo image directory') +parser.add_argument('--data_train', type=str, default='DIV2K', + help='train dataset name') +parser.add_argument('--data_test', type=str, default='DIV2K', + help='test dataset name') +parser.add_argument('--data_range', type=str, default='1-800/801-810', + help='train/test data range') +parser.add_argument('--ext', type=str, default='sep', + help='dataset file extension') +parser.add_argument('--scale', type=str, default='4', + help='super resolution scale') +parser.add_argument('--patch_size', type=int, default=48, + help='output patch size') +parser.add_argument('--rgb_range', type=int, default=255, + help='maximum value of RGB') +parser.add_argument('--n_colors', type=int, default=3, + help='number of color channels to use') +parser.add_argument('--chop', action='store_true', + help='enable memory-efficient forward') +parser.add_argument('--no_augment', action='store_true', + help='do not use data augmentation') + +# Model specifications +parser.add_argument('--model', default='vtip', + help='model name') + +parser.add_argument('--act', type=str, default='relu', + help='activation function') +parser.add_argument('--pre_train', type=str, default='', + help='pre-trained model directory') +parser.add_argument('--extend', type=str, default='.', + help='pre-trained model directory') +parser.add_argument('--n_resblocks', type=int, default=16, + help='number of residual blocks') +parser.add_argument('--n_feats', type=int, default=64, + help='number of feature maps') +parser.add_argument('--res_scale', type=float, default=1, + help='residual scaling') +parser.add_argument('--shift_mean', default=True, + help='subtract pixel mean from the input') +parser.add_argument('--dilation', action='store_true', + help='use dilated convolution') +parser.add_argument('--precision', type=str, default='single', + choices=('single', 'half'), + help='FP precision for test (single | half)') + +# Option for Residual dense network (RDN) +parser.add_argument('--G0', type=int, default=64, + help='default number of filters. (Use in RDN)') +parser.add_argument('--RDNkSize', type=int, default=3, + help='default kernel size. (Use in RDN)') +parser.add_argument('--RDNconfig', type=str, default='B', + help='parameters config of RDN. 
(Use in RDN)') + +# Option for Residual channel attention network (RCAN) +parser.add_argument('--n_resgroups', type=int, default=10, + help='number of residual groups') +parser.add_argument('--reduction', type=int, default=16, + help='number of feature maps reduction') + +# Training specifications +parser.add_argument('--reset', action='store_true', + help='reset the training') +parser.add_argument('--test_every', type=int, default=1000, + help='do test per every N batches') +parser.add_argument('--epochs', type=int, default=300, + help='number of epochs to train') +parser.add_argument('--batch_size', type=int, default=16, + help='input batch size for training') +parser.add_argument('--test_batch_size', type=int, default=1, + help='input batch size for training') +parser.add_argument('--split_batch', type=int, default=1, + help='split the batch into smaller chunks') +parser.add_argument('--self_ensemble', action='store_true', + help='use self-ensemble method for test') +parser.add_argument('--test_only', action='store_true', + help='set this option to test the model') +parser.add_argument('--gan_k', type=int, default=1, + help='k value for adversarial loss') + +# Optimization specifications +parser.add_argument('--lr', type=float, default=1e-4, + help='learning rate') +parser.add_argument('--decay', type=str, default='200', + help='learning rate decay type') +parser.add_argument('--gamma', type=float, default=0.5, + help='learning rate decay factor for step decay') +parser.add_argument('--optimizer', default='ADAM', + choices=('SGD', 'ADAM', 'RMSprop'), + help='optimizer to use (SGD | ADAM | RMSprop)') +parser.add_argument('--momentum', type=float, default=0.9, + help='SGD momentum') +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), + help='ADAM beta') +parser.add_argument('--epsilon', type=float, default=1e-8, + help='ADAM epsilon for numerical stability') +parser.add_argument('--weight_decay', type=float, default=0, + help='weight decay') +parser.add_argument('--gclip', type=float, default=0, + help='gradient clipping threshold (0 = no clipping)') + +# Loss specifications +parser.add_argument('--loss', type=str, default='1*L1', + help='loss function configuration') +parser.add_argument('--skip_threshold', type=float, default='1e8', + help='skipping batch that has large error') + +# Log specifications +parser.add_argument('--save', type=str, default='/cache/results/edsr_baseline_x2/', + help='file name to save') +parser.add_argument('--load', type=str, default='', + help='file name to load') +parser.add_argument('--resume', type=int, default=0, + help='resume from specific checkpoint') +parser.add_argument('--save_models', action='store_true', + help='save all intermediate models') +parser.add_argument('--print_every', type=int, default=100, + help='how many batches to wait before logging training status') +parser.add_argument('--save_results', action='store_true', + help='save output results') +parser.add_argument('--save_gt', action='store_true', + help='save low-resolution and high-resolution images together') + +parser.add_argument('--scalelr', type=int, default=0) +# cloud +parser.add_argument('--moxfile', type=int, default=1) +parser.add_argument('--imagenet', type=int, default=0) +parser.add_argument('--data_url', type=str, help='path to dataset') +parser.add_argument('--train_url', type=str, help='train_dir') +parser.add_argument('--pretrain', type=str, default='') +parser.add_argument('--pth_path', type=str, default='') +parser.add_argument('--load_query', 
type=int, default=0) +# transformer +parser.add_argument('--patch_dim', type=int, default=3) +parser.add_argument('--num_heads', type=int, default=12) +parser.add_argument('--num_layers', type=int, default=12) +parser.add_argument('--dropout_rate', type=float, default=0) +parser.add_argument('--no_norm', action='store_true') +parser.add_argument('--post_norm', action='store_true') +parser.add_argument('--no_mlp', action='store_true') +parser.add_argument('--test', action='store_true') +parser.add_argument('--chop_new', action='store_true') +parser.add_argument('--pos_every', action='store_true') +parser.add_argument('--no_pos', action='store_true') +parser.add_argument('--num_queries', type=int, default=6) +parser.add_argument('--reweight', action='store_true') + +# denoise +parser.add_argument('--denoise', action='store_true') +parser.add_argument('--sigma', type=float, default=25) + +# derain +parser.add_argument('--derain', action='store_true') +parser.add_argument('--finetune', action='store_true') +parser.add_argument('--derain_test', type=int, default=10) +# alltask +parser.add_argument('--alltask', action='store_true') + +# dehaze +parser.add_argument('--dehaze', action='store_true') +parser.add_argument('--dehaze_test', type=int, default=100) +parser.add_argument('--indoor', action='store_true') +parser.add_argument('--outdoor', action='store_true') +parser.add_argument('--nochange', action='store_true') +# deblur +parser.add_argument('--deblur', action='store_true') +parser.add_argument('--deblur_test', type=int, default=1000) + +# distribute +parser.add_argument('--init_method', type=str, + default=None, help='master address') +parser.add_argument('--rank', type=int, default=0, + help='Index of current task') +parser.add_argument('--world_size', type=int, default=1, + help='Total number of tasks') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--dist-url', default='', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--distribute', action='store_true') + +args, unparsed = parser.parse_known_args() +template.set_template(args) + +args.scale = [int(x) for x in args.scale.split("+")] +args.data_train = args.data_train.split('+') +args.data_test = args.data_test.split('+') + +if args.epochs == 0: + args.epochs = 1e8 + +for arg in vars(args): + if vars(args)[arg] == 'True': + vars(args)[arg] = True + elif vars(args)[arg] == 'False': + vars(args)[arg] = False diff --git a/model_zoo/research/cv/IPT/src/data/__init__.py b/model_zoo/research/cv/IPT/src/data/__init__.py new file mode 100755 index 00000000000..b298c60ebc2 --- /dev/null +++ b/model_zoo/research/cv/IPT/src/data/__init__.py @@ -0,0 +1,35 @@ +"""data""" +# Copyright 2021 Huawei Technologies Co., Ltd" +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from importlib import import_module + + +class Data: + """data""" + + def __init__(self, args): + self.loader_train = None + self.loader_test = [] + for d in args.data_test: + if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']: + m = import_module('data.benchmark') + testset = getattr(m, 'Benchmark')(args, train=False, name=d) + else: + module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' + m = import_module('data.' + module_name.lower()) + testset = getattr(m, module_name)(args, train=False, name=d) + + self.loader_test.append( + testset + ) diff --git a/model_zoo/research/cv/IPT/src/data/common.py b/model_zoo/research/cv/IPT/src/data/common.py new file mode 100755 index 00000000000..3155e1381bc --- /dev/null +++ b/model_zoo/research/cv/IPT/src/data/common.py @@ -0,0 +1,93 @@ +"""common""" +# Copyright 2021 Huawei Technologies Co., Ltd + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import random + +import numpy as np +import skimage.color as sc + + +def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): + """common""" + ih, iw = args[0].shape[:2] + + tp = patch_size + ip = tp // scale + + ix = random.randrange(0, iw - ip + 1) + iy = random.randrange(0, ih - ip + 1) + + if not input_large: + tx, ty = scale * ix, scale * iy + else: + tx, ty = ix, iy + + ret = [ + args[0][iy:iy + ip, ix:ix + ip, :], + *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] + ] + + return ret + + +def set_channel(*args, n_channels=3): + """common""" + + def _set_channel(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + + c = img.shape[2] + if n_channels == 1 and c == 3: + img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) + elif n_channels == 3 and c == 1: + img = np.concatenate([img] * n_channels, 2) + + return img[:, :, :n_channels] + + return [_set_channel(a) for a in args] + + +def np2Tensor(*args, rgb_range=255): + """common""" + + def _np2Tensor(img): + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) + tensor = np_transpose.astype(np.float32) + tensor = tensor * (rgb_range / 255) + # tensor = torch.from_numpy(np_transpose).float() + # tensor.mul_(rgb_range / 255) + + return tensor + + return [_np2Tensor(a) for a in args] + + +def augment(*args, hflip=True, rot=True): + """common""" + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(a) for a in args] diff --git a/model_zoo/research/cv/IPT/src/data/srdata.py b/model_zoo/research/cv/IPT/src/data/srdata.py new file mode 100755 index 00000000000..ab31718d0a1 --- 
/dev/null +++ b/model_zoo/research/cv/IPT/src/data/srdata.py @@ -0,0 +1,301 @@ +"""srdata""" +# Copyright 2021 Huawei Technologies Co., Ltd + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import glob +import random +import pickle + +from src.data import common + +import numpy as np +import imageio + + +def search(root, target="JPEG"): + """srdata""" + item_list = [] + items = os.listdir(root) + for item in items: + path = os.path.join(root, item) + if os.path.isdir(path): + item_list.extend(search(path, target)) + elif path.split('/')[-1].startswith(target): + item_list.append(path) + elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]): + item_list.append(path) + else: + item_list = [] + return item_list + + +def search_dehaze(root, target="JPEG"): + """srdata""" + item_list = [] + items = os.listdir(root) + for item in items: + path = os.path.join(root, item) + if os.path.isdir(path): + extend_list = search_dehaze(path, target) + if extend_list is not None: + item_list.extend(extend_list) + elif path.split('/')[-2].endswith(target): + item_list.append(path) + return item_list + + +class SRData(): + """srdata""" + + def __init__(self, args, name='', train=True, benchmark=False): + self.args = args + self.name = name + self.train = train + self.split = 'train' if train else 'test' + self.do_eval = True + self.benchmark = benchmark + self.input_large = (args.model == 'VDSR') + self.scale = args.scale + self.idx_scale = 0 + + if self.args.derain: + self.derain_test = os.path.join(args.dir_data, "Rain100L") + self.derain_lr_test = search(self.derain_test, "rain") + self.derain_hr_test = [path.replace( + "rainy/", "no") for path in self.derain_lr_test] + self._set_filesystem(args.dir_data) + if args.ext.find('img') < 0: + path_bin = os.path.join(self.apath, 'bin') + os.makedirs(path_bin, exist_ok=True) + + list_hr, list_lr = self._scan() + if args.ext.find('img') >= 0 or benchmark: + self.images_hr, self.images_lr = list_hr, list_lr + elif args.ext.find('sep') >= 0: + os.makedirs( + self.dir_hr.replace(self.apath, path_bin), + exist_ok=True + ) + for s in self.scale: + if s == 1: + os.makedirs( + os.path.join(self.dir_hr), + exist_ok=True + ) + else: + os.makedirs( + os.path.join( + self.dir_lr.replace(self.apath, path_bin), + 'X{}'.format(s) + ), + exist_ok=True + ) + + self.images_hr, self.images_lr = [], [[] for _ in self.scale] + for h in list_hr: + b = h.replace(self.apath, path_bin) + b = b.replace(self.ext[0], '.pt') + self.images_hr.append(b) + self._check_and_load(args.ext, h, b, verbose=True) + for i, ll in enumerate(list_lr): + for l in ll: + b = l.replace(self.apath, path_bin) + b = b.replace(self.ext[1], '.pt') + self.images_lr[i].append(b) + self._check_and_load(args.ext, l, b, verbose=True) + + # Below functions as used to prepare images + def _scan(self): + """srdata""" + names_hr = sorted( + glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) + ) + names_lr = 
[[] for _ in self.scale] + for f in names_hr: + filename, _ = os.path.splitext(os.path.basename(f)) + for si, s in enumerate(self.scale): + if s != 1: + scale = s + names_lr[si].append(os.path.join( + self.dir_lr, 'X{}/{}x{}{}'.format( + s, filename, scale, self.ext[1] + ) + )) + for si, s in enumerate(self.scale): + if s == 1: + names_lr[si] = names_hr + return names_hr, names_lr + + def _set_filesystem(self, dir_data): + self.apath = os.path.join(dir_data, self.name[0]) + self.dir_hr = os.path.join(self.apath, 'HR') + self.dir_lr = os.path.join(self.apath, 'LR_bicubic') + self.ext = ('.png', '.png') + + def _check_and_load(self, ext, img, f, verbose=True): + if not os.path.isfile(f) or ext.find('reset') >= 0: + if verbose: + print('Making a binary: {}'.format(f)) + with open(f, 'wb') as _f: + pickle.dump(imageio.imread(img), _f) + + def __getitem__(self, idx): + if self.args.model == 'vtip' and self.args.derain and self.scale[ + self.idx_scale] == 1 and not self.args.finetune: + norain, rain, _ = self._load_rain_test(idx) + pair = common.set_channel( + *[rain, norain], n_channels=self.args.n_colors) + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) + return pair_t[0], pair_t[1] + if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1: + hr, _ = self._load_file_hr(idx) + pair = self.get_patch_hr(hr) + pair = common.set_channel(*[pair], n_channels=self.args.n_colors) + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) + noise = np.random.randn(*pair_t[0].shape) * self.args.sigma + lr = pair_t[0] + noise + lr = np.float32(np.clip(lr, 0, 255)) + return lr, pair_t[0] + lr, hr, _ = self._load_file(idx) + pair = self.get_patch(lr, hr) + pair = common.set_channel(*pair, n_channels=self.args.n_colors) + pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) + + return pair_t[0], pair_t[1] + + def __len__(self): + if self.train: + return len(self.images_hr) * self.repeat + + if self.args.derain and not self.args.alltask: + return int(len(self.derain_hr_test) / self.args.derain_test) + return len(self.images_hr) + + def _get_index(self, idx): + """srdata""" + if self.train: + return idx % len(self.images_hr) + return idx + + def _load_file_hr(self, idx): + """srdata""" + idx = self._get_index(idx) + f_hr = self.images_hr[idx] + + filename, _ = os.path.splitext(os.path.basename(f_hr)) + if self.args.ext == 'img' or self.benchmark: + hr = imageio.imread(f_hr) + elif self.args.ext.find('sep') >= 0: + with open(f_hr, 'rb') as _f: + hr = pickle.load(_f) + + return hr, filename + + def _load_rain(self, idx, rain_img=False): + """srdata""" + idx = random.randint(0, len(self.derain_img_list) - 1) + f_lr = self.derain_img_list[idx] + if rain_img: + norain = imageio.imread(f_lr.replace("rainstreak", "norain")) + rain = imageio.imread(f_lr.replace("rainstreak", "rain")) + return norain, rain + lr = imageio.imread(f_lr) + return lr + + def _load_rain_test(self, idx): + """srdata""" + f_hr = self.derain_hr_test[idx] + f_lr = self.derain_lr_test[idx] + filename, _ = os.path.splitext(os.path.basename(f_lr)) + norain = imageio.imread(f_hr) + rain = imageio.imread(f_lr) + return norain, rain, filename + + def _load_denoise(self, idx): + """srdata""" + idx = self._get_index(idx) + f_lr = self.images_hr[idx] + norain = imageio.imread(f_lr) + rain = imageio.imread(f_lr.replace("HR", "LR_bicubic")) + return norain, rain + + def _load_file(self, idx): + """srdata""" + idx = self._get_index(idx) + + f_hr = self.images_hr[idx] + f_lr = 
self.images_lr[self.idx_scale][idx] + + filename, _ = os.path.splitext(os.path.basename(f_hr)) + if self.args.ext == 'img' or self.benchmark: + hr = imageio.imread(f_hr) + lr = imageio.imread(f_lr) + elif self.args.ext.find('sep') >= 0: + with open(f_hr, 'rb') as _f: + hr = pickle.load(_f) + with open(f_lr, 'rb') as _f: + lr = pickle.load(_f) + + return lr, hr, filename + + def get_patch_hr(self, hr): + """srdata""" + if self.train: + hr = self.get_patch_img_hr( + hr, + patch_size=self.args.patch_size, + scale=1 + ) + + return hr + + def get_patch_img_hr(self, img, patch_size=96, scale=2): + """srdata""" + ih, iw = img.shape[:2] + + tp = patch_size + ip = tp // scale + + ix = random.randrange(0, iw - ip + 1) + iy = random.randrange(0, ih - ip + 1) + + ret = img[iy:iy + ip, ix:ix + ip, :] + + return ret + + def get_patch(self, lr, hr): + """srdata""" + scale = self.scale[self.idx_scale] + if self.train: + lr, hr = common.get_patch( + lr, hr, + patch_size=self.args.patch_size * scale, + scale=scale, + multi=(len(self.scale) > 1) + ) + if not self.args.no_augment: + lr, hr = common.augment(lr, hr) + else: + ih, iw = lr.shape[:2] + hr = hr[0:ih * scale, 0:iw * scale] + + return lr, hr + + def set_scale(self, idx_scale): + """srdata""" + if not self.input_large: + self.idx_scale = idx_scale + else: + self.idx_scale = random.randint(0, len(self.scale) - 1) diff --git a/model_zoo/research/cv/IPT/src/foldunfold_stride.py b/model_zoo/research/cv/IPT/src/foldunfold_stride.py new file mode 100755 index 00000000000..69b28655fad --- /dev/null +++ b/model_zoo/research/cv/IPT/src/foldunfold_stride.py @@ -0,0 +1,241 @@ +'''stride''' +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+import numpy as np
+from mindspore import nn
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+
+class _stride_unfold_(nn.Cell):
+    '''stride'''
+
+    def __init__(self,
+                 kernel_size,
+                 stride=-1):
+
+        super(_stride_unfold_, self).__init__()
+        if stride == -1:
+            self.stride = kernel_size
+        else:
+            self.stride = stride
+        self.kernel_size = kernel_size
+
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+        self.unfold = _unfold_(kernel_size)
+
+    def construct(self, x):
+        """stride"""
+        N, C, H, W = x.shape
+        leftup_idx_x = []
+        leftup_idx_y = []
+        nh = int(H / self.stride)
+        nw = int(W / self.stride)
+        for i in range(nh):
+            leftup_idx_x.append(i * self.stride)
+        for i in range(nw):
+            leftup_idx_y.append(i * self.stride)
+        NumBlock_x = len(leftup_idx_x)
+        NumBlock_y = len(leftup_idx_y)
+        zeroslike = P.ZerosLike()
+        cc_2 = P.Concat(axis=2)
+        cc_3 = P.Concat(axis=3)
+        unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
+                           NumBlock_y * self.kernel_size), mstype.float32)
+        N, C, H, W = unf_x.shape
+        for i in range(NumBlock_x):
+            for j in range(NumBlock_y):
+                unf_i = i * self.kernel_size
+                unf_j = j * self.kernel_size
+                org_i = leftup_idx_x[i]
+                org_j = leftup_idx_y[j]
+                fills = x[:, :, org_i:org_i + self.kernel_size,
+                          org_j:org_j + self.kernel_size]
+                unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
+                    (zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
+                        unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
+                                zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
+        y = self.unfold(unf_x)
+        return y
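+
+
+# [Editor's note] Illustrative NumPy analogue of the unfold/fold pair in this
+# file for the non-overlapping case (stride == kernel_size, kernel_size
+# divides H and W). Added for intuition only; the Cells below are the real
+# implementation, and their per-patch element ordering differs in detail.
+def _np_unfold_sketch(x, k):
+    """(N, C, H, W) -> (N, nH*nW, C*k*k): flatten each k x k patch."""
+    n, c, h, w = x.shape
+    p = x.reshape(n, c, h // k, k, w // k, k).transpose(0, 2, 4, 1, 3, 5)
+    return p.reshape(n, (h // k) * (w // k), c * k * k)
+
+
+def _np_fold_sketch(p, k, h, w):
+    """Inverse of _np_unfold_sketch: reassemble patches into (N, C, H, W)."""
+    n, _, d = p.shape
+    c = d // (k * k)
+    q = p.reshape(n, h // k, w // k, c, k, k).transpose(0, 3, 1, 4, 2, 5)
+    return q.reshape(n, c, h, w)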
+
+
+class _stride_fold_(nn.Cell):
+    '''stride'''
+
+    def __init__(self,
+                 kernel_size,
+                 output_shape=(-1, -1),
+                 stride=-1):
+
+        super(_stride_fold_, self).__init__()
+        if isinstance(kernel_size, (list, tuple)):
+            self.kernel_size = kernel_size
+        else:
+            self.kernel_size = [kernel_size, kernel_size]
+
+        if stride == -1:
+            self.stride = self.kernel_size[0]
+        else:
+            self.stride = stride
+
+        self.output_shape = output_shape
+
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+        self.fold = _fold_(kernel_size)
+
+    def construct(self, x):
+        '''stride'''
+        zeroslike = P.ZerosLike()
+        cc_2 = P.Concat(axis=2)
+        cc_3 = P.Concat(axis=3)
+        if self.output_shape[0] == -1:
+            large_x = self.fold(x)
+            N, C, H, _ = large_x.shape
+            leftup_idx = []
+            for i in range(0, H, self.kernel_size[0]):
+                leftup_idx.append(i)
+            NumBlock = len(leftup_idx)
+            fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
+                                (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
+
+            for i in range(NumBlock):
+                for j in range(NumBlock):
+                    fold_i = i * self.stride
+                    fold_j = j * self.stride
+                    org_i = leftup_idx[i]
+                    org_j = leftup_idx[j]
+                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
+                                    org_j:org_j + self.kernel_size[1]]
+                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
+                        (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
+                            fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
+                                    zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
+            y = fold_x
+        else:
+            NumBlock_x = int(
+                (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
+            NumBlock_y = int(
+                (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
+            large_shape = [NumBlock_x * self.kernel_size[0],
+                           NumBlock_y * self.kernel_size[1]]
+            self.fold = _fold_(self.kernel_size, large_shape)
+            large_x = self.fold(x)
+            N, C, H, _ = large_x.shape
+            leftup_idx_x = []
+            leftup_idx_y = []
+            for i in range(NumBlock_x):
+                leftup_idx_x.append(i * self.kernel_size[0])
+            for i in range(NumBlock_y):
+                leftup_idx_y.append(i * self.kernel_size[1])
+            fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
+                                (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
+            for i in range(NumBlock_x):
+                for j in range(NumBlock_y):
+                    fold_i = i * self.stride
+                    fold_j = j * self.stride
+                    org_i = leftup_idx_x[i]
+                    org_j = leftup_idx_y[j]
+                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
+                                    org_j:org_j + self.kernel_size[1]]
+                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
+                        (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
+                            fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
+                                    zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
+            y = fold_x
+        return y
+
+
+class _unfold_(nn.Cell):
+    '''stride'''
+
+    def __init__(self,
+                 kernel_size,
+                 stride=-1):
+
+        super(_unfold_, self).__init__()
+        if stride == -1:
+            self.stride = kernel_size
+        else:
+            self.stride = stride
+        self.kernel_size = kernel_size
+
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+
+    def construct(self, x):
+        '''stride'''
+        N, C, H, W = x.shape
+        numH = int(H / self.kernel_size)
+        numW = int(W / self.kernel_size)
+        if numH * self.kernel_size != H or numW * self.kernel_size != W:
+            x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
+        output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
+
+        output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
+
+        output_img = self.reshape(output_img, (N, C, int(
+            numH * numW), self.kernel_size, self.kernel_size))
+
+        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
+
+        output_img = self.reshape(output_img, (N, int(numH * numW), -1))
+        return output_img
+
+
+class _fold_(nn.Cell):
+    '''stride'''
+
+    def __init__(self,
+                 kernel_size,
+                 output_shape=(-1, -1),
+                 stride=-1):
+
+        super(_fold_, self).__init__()
+
+        if isinstance(kernel_size, (list, tuple)):
+            self.kernel_size = kernel_size
+        else:
+            self.kernel_size = [kernel_size, kernel_size]
+
+        if stride == -1:
+            self.stride = self.kernel_size[0]
+        else:
+            self.stride = stride
+        self.output_shape = output_shape
+
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+
+    def construct(self, x):
+        '''stride'''
+        N, C, L = x.shape
+        org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
+        if self.output_shape[0] == -1:
+            numH = int(np.sqrt(C))
+            numW = int(np.sqrt(C))
+            org_H = int(numH * self.kernel_size[0])
+            org_W = org_H
+        else:
+            org_H = int(self.output_shape[0])
+            org_W = int(self.output_shape[1])
+            numH = int(org_H / self.kernel_size[0])
+            numW = int(org_W / self.kernel_size[1])
+
+        output_img = self.reshape(
+            x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
+
+        output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
+        output_img = self.reshape(
+            output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
+
+        output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
+
+        output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
+        return output_img
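In `src/ipt.py` below, `MultiheadAttention` keeps one packed projection weight `in_proj_layer` of shape `(3 * hidden_width, q_tensor_width)` and, when query, key and value are the same tensor, computes Q, K and V with a single matmul followed by a split. A NumPy sketch of that packing (illustrative; the real code uses `P.MatMul(transpose_b=True)` and `P.Split(0, 3)`):

```python
import numpy as np

hidden, width, tokens = 8, 8, 5
w_qkv = np.random.rand(3 * hidden, width).astype(np.float32)  # in_proj_layer
x = np.random.rand(tokens, width).astype(np.float32)          # flattened (batch*seq) tokens

qkv = w_qkv @ x.T                    # one matmul: (3*hidden, tokens)
q, k, v = np.split(qkv, 3, axis=0)   # three equal chunks of shape (hidden, tokens)
```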
diff --git a/model_zoo/research/cv/IPT/src/ipt.py b/model_zoo/research/cv/IPT/src/ipt.py
new file mode 100755
index 00000000000..8f51a974f07
--- /dev/null
+++ b/model_zoo/research/cv/IPT/src/ipt.py
@@ -0,0 +1,1023 @@
+"""ipt"""
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import math
+import copy
+import numpy as np
+from mindspore import nn
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter
+
+
+class MultiheadAttention(nn.Cell):
+    """
+    Apply multi-headed attention from "from_tensor" to "to_tensor".
+
+    Args:
+        q_tensor_width (int): Size of the last dim of the query tensor.
+        k_tensor_width (int): Size of the last dim of the key tensor.
+        v_tensor_width (int): Size of the last dim of the value tensor.
+        hidden_width (int): Width of the projected multi-head hidden representation;
+            must be divisible by num_attention_heads.
+        out_tensor_width (int): Size of the last dim of the output tensor.
+        num_attention_heads (int): Number of attention heads. Default: 1.
+        query_act (str): Activation function for the query transform. Default: None.
+        key_act (str): Activation function for the key transform. Default: None.
+        value_act (str): Activation function for the value transform. Default: None.
+        out_act (str): Activation function for the output transform. Default: None.
+        has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
+        attention_probs_dropout_prob (float): The dropout probability for
+            MultiheadAttention. Default: 0.0.
+        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
+        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
+        do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
+            tensor. Default: False.
+        same_dim (bool): Whether query/key/value share one packed projection weight. Default: True.
+        compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32.
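+
+    Example (illustrative; placeholder widths, not taken from the repo):
+        attn = MultiheadAttention(q_tensor_width=64, k_tensor_width=64,
+                                  v_tensor_width=64, hidden_width=64,
+                                  out_tensor_width=64, num_attention_heads=4)
+        # q, k and v are Tensors of shape (batch, seq, 64); the call returns
+        # a Tensor of the same shape via construct().
+        out = attn(q, k, v, batch_size=2, seq_length=16)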
+ """ + + def __init__(self, + q_tensor_width, + k_tensor_width, + v_tensor_width, + hidden_width, + out_tensor_width, + num_attention_heads=1, + query_act=None, + key_act=None, + value_act=None, + out_act=None, + has_attention_mask=True, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=False, + compute_type=mstype.float32, + same_dim=True): + super(MultiheadAttention, self).__init__() + self.num_attention_heads = num_attention_heads + self.size_per_head = int(hidden_width / num_attention_heads) + self.has_attention_mask = has_attention_mask + assert has_attention_mask + self.use_one_hot_embeddings = use_one_hot_embeddings + self.initializer_range = initializer_range + self.do_return_2d_tensor = do_return_2d_tensor + self.same_dim = same_dim + + self.scores_mul = Tensor( + [1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.reshape = P.Reshape() + self.shape_q_2d = (-1, q_tensor_width) + self.shape_k_2d = (-1, k_tensor_width) + self.shape_v_2d = (-1, v_tensor_width) + self.hidden_width = hidden_width + # units = num_attention_heads * self.size_per_head + if self.same_dim: + self.in_proj_layer = \ + Parameter(Tensor(np.random.rand(hidden_width * 3, + q_tensor_width), dtype=compute_type), name="weight") + else: + self.query_layer = nn.Dense(q_tensor_width, + hidden_width, + activation=query_act, + has_bias=False).to_float(compute_type) + self.key_layer = nn.Dense(k_tensor_width, + hidden_width, + activation=key_act, + has_bias=False).to_float(compute_type) + self.value_layer = nn.Dense(q_tensor_width, + hidden_width, + activation=value_act, + has_bias=False).to_float(compute_type) + self.out_proj = nn.Dense(hidden_width, + out_tensor_width, + activation=out_act, + has_bias=False).to_float(compute_type) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1. 
- attention_probs_dropout_prob) + self.use_dropout = attention_probs_dropout_prob > 0 + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + + self.softmax_cast = P.Cast() + self.matmul_dense = P.MatMul(transpose_b=True) + self.split = P.Split(0, 3) + + def construct(self, tensor_q, tensor_k, tensor_v, batch_size, seq_length, attention_mask=None): + """Apply multihead attention.""" + self.batch_size = batch_size + shape_qkv = (self.batch_size, -1, + self.num_attention_heads, self.size_per_head) + shape_linear = (self.batch_size * seq_length, + self.num_attention_heads * self.size_per_head) + if self.do_return_2d_tensor: + shape_return = (self.batch_size * seq_length, + self.num_attention_heads * self.size_per_head) + if seq_length == -1: + shape_return = (-1, self.num_attention_heads * + self.size_per_head) + else: + shape_return = (self.batch_size, seq_length, + self.num_attention_heads * self.size_per_head) + + tensor_q_2d = self.reshape(tensor_q, self.shape_q_2d) + tensor_k_2d = self.reshape(tensor_k, self.shape_k_2d) + tensor_v_2d = self.reshape(tensor_v, self.shape_v_2d) + + if P.Equal()(tensor_q_2d, tensor_v_2d)[0][0]: + x = self.matmul_dense(self.in_proj_layer, tensor_q_2d) + query_out, key_out, value_out = self.split(x) + + elif self.same_dim: + _start = int(0) + _end = int(self.hidden_width) + _w = self.in_proj_layer[_start:_end, :] + # _b = None + query_out = self.matmul_dense(_w, tensor_q_2d) + + _start = int(self.hidden_width) + _end = int(self.hidden_width * 2) + _w = self.in_proj_layer[_start:_end, :] + # _b = None + key_out = self.matmul_dense(_w, tensor_k_2d) + + _start = int(self.hidden_width * 2) + _end = None + _w = self.in_proj_layer[_start:] + # _b = None + value_out = self.matmul_dense(_w, tensor_v_2d) + else: + query_out = self.query_layer(tensor_q_2d) + key_out = self.key_layer(tensor_k_2d) + value_out = self.value_layer(tensor_v_2d) + query_out = self.transpose(query_out, (1, 0)) + key_out = self.transpose(key_out, (1, 0)) + value_out = self.transpose(value_out, (1, 0)) + query_layer = self.reshape(query_out, shape_qkv) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, shape_qkv) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + attention_scores = self.multiply(attention_scores, self.scores_mul) + + attention_scores = self.softmax_cast(attention_scores, mstype.float32) + attention_probs = self.softmax(attention_scores) + attention_probs = self.softmax_cast( + attention_probs, self.get_dtype(key_layer)) + if self.use_dropout: + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, shape_qkv) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, shape_linear) + + context_layer = self.out_proj(context_layer) + context_layer = self.reshape(context_layer, shape_return) + return context_layer + + +class TransformerEncoderLayer(nn.Cell): + """ipt""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu"): + super().__init__() + + self.self_attn = MultiheadAttention(q_tensor_width=d_model, + k_tensor_width=d_model, + v_tensor_width=d_model, + hidden_width=d_model, + 
out_tensor_width=d_model, + num_attention_heads=nhead, + attention_probs_dropout_prob=dropout) + self.linear1 = nn.Dense(d_model, dim_feedforward) + self.dropout = nn.Dropout(1. - dropout) + self.linear2 = nn.Dense(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm([d_model]) + self.norm2 = nn.LayerNorm([d_model]) + self.dropout1 = nn.Dropout(1. - dropout) + self.dropout2 = nn.Dropout(1. - dropout) + self.reshape = P.Reshape() + + self.activation = P.ReLU() + + def with_pos_embed(self, tensor, pos): + """ipt""" + return tensor if pos is None else tensor + pos + + def construct(self, src, pos=None): + """ipt""" + b, n, d = src.shape + permute_linear = (b * n, d) + permute_recover = (b, n, d) + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, src2, batch_size=b, seq_length=n) + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.reshape(src2, permute_linear) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src2 = self.reshape(src2, permute_recover) + src = src + self.dropout2(src2) + return src + + +class TransformerDecoderLayer(nn.Cell): + """ipt""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu"): + super().__init__() + self.self_attn = MultiheadAttention(q_tensor_width=d_model, + k_tensor_width=d_model, + v_tensor_width=d_model, + hidden_width=d_model, + out_tensor_width=d_model, + num_attention_heads=nhead, + attention_probs_dropout_prob=dropout) + self.multihead_attn = MultiheadAttention(q_tensor_width=d_model, + k_tensor_width=d_model, + v_tensor_width=d_model, + hidden_width=d_model, + out_tensor_width=d_model, + num_attention_heads=nhead, + attention_probs_dropout_prob=dropout) + self.linear1 = nn.Dense(d_model, dim_feedforward) + self.dropout = nn.Dropout(1. - dropout) + self.linear2 = nn.Dense(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm([d_model]) + self.norm2 = nn.LayerNorm([d_model]) + self.norm3 = nn.LayerNorm([d_model]) + self.dropout1 = nn.Dropout(1. - dropout) + self.dropout2 = nn.Dropout(1. - dropout) + self.dropout3 = nn.Dropout(1. 
- dropout) + self.reshape = P.Reshape() + self.activation = P.ReLU() + + def with_pos_embed(self, tensor, pos): + """ipt""" + return tensor if pos is None else tensor + pos + + def construct(self, tgt, memory, pos=None, query_pos=None): + """ipt""" + b, n, d = tgt.shape + permute_linear = (b * n, d) + permute_recover = (b, n, d) + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, tensor_v=tgt2, batch_size=b, seq_length=n) + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos), + tensor_k=self.with_pos_embed(memory, pos), + tensor_v=memory, + batch_size=b, seq_length=n) + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.reshape(tgt2, permute_linear) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt2 = self.reshape(tgt2, permute_recover) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class TransformerEncoder(nn.Cell): + """ipt""" + + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + def construct(self, src, pos=None): + """ipt""" + output = src + + for layer in self.layers: + output = layer(output, pos=pos) + + return output + + +class TransformerDecoder(nn.Cell): + """ipt""" + + def __init__(self, decoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + def construct(self, tgt, memory, pos=None, query_pos=None): + """ipt""" + output = tgt + + for layer in self.layers: + output = layer(output, memory, pos=pos, query_pos=query_pos) + return output + + +def _get_clones(module, N): + """ipt""" + return nn.CellList([copy.deepcopy(module) for i in range(N)]) + + +class LearnedPositionalEncoding(nn.Cell): + """ipt""" + + def __init__(self, max_position_embeddings, embedding_dim, seq_length): + super(LearnedPositionalEncoding, self).__init__() + self.pe = nn.Embedding( + max_position_embeddings, embedding_dim) + self.seq_length = seq_length + + self.position_ids = Tensor(np.arange(self.seq_length).astype(np.int32)) + self.reshape = P.Reshape() + self.position_ids = self.reshape( + self.position_ids, (1, self.seq_length)) + + def construct(self, x, position_ids=None): + """ipt""" + if position_ids is None: + position_ids = self.position_ids[:, : self.seq_length] + + position_embeddings = self.pe(position_ids) + return position_embeddings + + +class VisionTransformer(nn.Cell): + """ipt""" + + def __init__( + self, + img_dim, + patch_dim, + num_channels, + embedding_dim, + num_heads, + num_layers, + hidden_dim, + num_queries, + positional_encoding_type="learned", + dropout_rate=0, + norm=False, + mlp=False, + pos_every=False, + no_pos=False + ): + super(VisionTransformer, self).__init__() + + assert embedding_dim % num_heads == 0 + assert img_dim % patch_dim == 0 + self.norm = norm + self.mlp = mlp + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.patch_dim = patch_dim + self.num_channels = num_channels + + self.img_dim = img_dim + self.pos_every = pos_every + self.num_patches = int((img_dim // patch_dim) ** 2) + self.seq_length = self.num_patches + self.flatten_dim = patch_dim * patch_dim * num_channels + + self.out_dim = patch_dim * patch_dim * num_channels + + self.no_pos = no_pos + + self.unf = _unfold_(patch_dim) + self.fold = _fold_(patch_dim) + + if self.mlp is not True: + self.linear_encoding = 
nn.Dense( + self.flatten_dim, embedding_dim) + self.mlp_head = nn.SequentialCell( + nn.Dense(embedding_dim, hidden_dim), + nn.Dropout(1. - dropout_rate), + nn.ReLU(), + nn.Dense(hidden_dim, self.out_dim), + nn.Dropout(1. - dropout_rate) + ) + + self.query_embed = nn.Embedding( + num_queries, embedding_dim * self.seq_length) + + encoder_layer = TransformerEncoderLayer( + embedding_dim, num_heads, hidden_dim, dropout_rate) + self.encoder = TransformerEncoder(encoder_layer, num_layers) + + decoder_layer = TransformerDecoderLayer( + embedding_dim, num_heads, hidden_dim, dropout_rate) + self.decoder = TransformerDecoder(decoder_layer, num_layers) + + self.reshape = P.Reshape() + self.tile = P.Tile() + self.transpose = P.Transpose() + if not self.no_pos: + self.position_encoding = LearnedPositionalEncoding( + self.seq_length, self.embedding_dim, self.seq_length + ) + + self.dropout_layer1 = nn.Dropout(1. - dropout_rate) + + def construct(self, x, query_idx): + """ipt""" + B, _, _, _ = x.shape + x = self.unf(x) + B, N, _ = x.shape + + if self.mlp is not True: + x = self.reshape(x, (int(B * N), -1)) + x = self.dropout_layer1(self.linear_encoding(x)) + x + x = self.reshape(x, (B, N, -1)) + query_embed = self.tile( + self.reshape(self.query_embed.embedding_table[int( + query_idx)], (1, self.seq_length, self.embedding_dim)), + (B, 1, 1)) + + if not self.no_pos: + pos = self.position_encoding(x) + + x = self.encoder(x + pos) + x = self.decoder(x, x, query_pos=query_embed) + + if self.mlp is not True: + x = self.reshape(x, (int(B * N), -1)) + x = self.mlp_head(x) + x + x = self.reshape(x, (B, N, -1)) + x = self.fold(x) + + return x + + +def default_conv(in_channels, out_channels, kernel_size, has_bias=True): + """ipt""" + return nn.Conv2d( + in_channels, out_channels, kernel_size, has_bias=has_bias) + + +class MeanShift(nn.Conv2d): + """ipt""" + + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + super(MeanShift, self).__init__(3, 3, kernel_size=1) + self.reshape = P.Reshape() + self.eye = P.Eye() + std = Tensor(rgb_std, mstype.float32) + self.weight.set_data( + self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1))) + self.weight.requires_grad = False + self.bias = Parameter( + sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False) + self.has_bias = True + + +class ResBlock(nn.Cell): + """ipt""" + + def __init__( + self, conv, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(conv(n_feats, n_feats, kernel_size, has_bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.SequentialCell(*m) + self.res_scale = res_scale + + self.mul = P.Mul() + + def construct(self, x): + """ipt""" + res = self.mul(self.body(x), self.res_scale) + res += x + + return res + + +def _pixelsf_(x, scale): + """ipt""" + N, C, iH, iW = x.shape + oH = int(iH * scale) + oW = int(iW * scale) + oC = int(C // (scale ** 2)) + + output = P.Reshape()(x, (N, oC, scale, scale, iH, iW)) + + output = P.Transpose()(output, (0, 1, 5, 3, 4, 2)) + + output = P.Reshape()(output, (N, oC, oH, oW)) + + output = P.Transpose()(output, (0, 1, 3, 2)) + + return output + + +class SmallUpSampler(nn.Cell): + """ipt""" + + def __init__(self, conv, upsize, n_feats, bn=False, act=False, bias=True): + super(SmallUpSampler, self).__init__() + self.conv = conv(n_feats, upsize * 
+
+
+class SmallUpSampler(nn.Cell):
+    """ipt"""
+
+    def __init__(self, conv, upsize, n_feats, bn=False, act=False, bias=True):
+        super(SmallUpSampler, self).__init__()
+        self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
+        self.reshape = P.Reshape()
+        self.upsize = upsize
+
+    def construct(self, x):
+        """ipt"""
+        x = self.conv(x)
+        output = _pixelsf_(x, self.upsize)
+        return output
+
+
+class Upsampler(nn.Cell):
+    """ipt"""
+
+    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
+        super(Upsampler, self).__init__()
+        m = []
+        if (scale & (scale - 1)) == 0:
+            for _ in range(int(math.log(scale, 2))):
+                m.append(SmallUpSampler(conv, 2, n_feats, bias=bias))
+
+        elif scale == 3:
+            m.append(SmallUpSampler(conv, 3, n_feats, bias=bias))
+        self.net = nn.SequentialCell(m)
+
+    def construct(self, x):
+        """ipt"""
+        return self.net(x)
+
+
+class IPT(nn.Cell):
+    """ipt"""
+
+    def __init__(self, args, conv=default_conv):
+        super(IPT, self).__init__()
+
+        self.scale_idx = 0
+
+        self.args = args
+
+        n_feats = args.n_feats
+        kernel_size = 3
+        act = nn.ReLU()
+
+        self.sub_mean = MeanShift(args.rgb_range)
+        self.add_mean = MeanShift(args.rgb_range, sign=1)
+
+        self.head = nn.CellList([
+            nn.SequentialCell(
+                conv(args.n_colors, n_feats, kernel_size),
+                ResBlock(conv, n_feats, 5, act=act),
+                ResBlock(conv, n_feats, 5, act=act)
+            ) for _ in args.scale
+        ])
+
+        self.body = VisionTransformer(img_dim=args.patch_size,
+                                      patch_dim=args.patch_dim,
+                                      num_channels=n_feats,
+                                      embedding_dim=n_feats * args.patch_dim * args.patch_dim,
+                                      num_heads=args.num_heads,
+                                      num_layers=args.num_layers,
+                                      hidden_dim=n_feats * args.patch_dim * args.patch_dim * 4,
+                                      num_queries=args.num_queries,
+                                      dropout_rate=args.dropout_rate,
+                                      mlp=args.no_mlp,
+                                      pos_every=args.pos_every,
+                                      no_pos=args.no_pos)
+
+        self.tail = nn.CellList([
+            nn.SequentialCell(
+                Upsampler(conv, s, n_feats, act=False),
+                conv(n_feats, args.n_colors, kernel_size)
+            ) for s in args.scale
+        ])
+
+        self.reshape = P.Reshape()
+        self.tile = P.Tile()
+        self.transpose = P.Transpose()
+
+    def construct(self, x):
+        """ipt"""
+        x = self.sub_mean(x)
+        x = self.head[self.scale_idx](x)
+        res = self.body(x, self.scale_idx)
+        res += x
+        x = self.tail[self.scale_idx](res)
+        x = self.add_mean(x)
+
+        return x
+
+    def set_scale(self, scale_idx):
+        """ipt"""
+        self.scale_idx = scale_idx
+
+    def infrc(self, x):
+        """ipt"""
+        forward_function = self.forward_chop_new
+
+        return forward_function(x)
+
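+    # forward_chop_new tiles the input into overlapping patch_size x patch_size
+    # crops (overlap = shave = patch_size / 4), pushes the crops through the
+    # network in batches, and folds the outputs back together; the image
+    # borders and the bottom-right corner are re-run separately (cut_h_new,
+    # cut_w_new, x_hw_cut) so that every output pixel comes from a full crop.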
+    def forward_chop_new(self, x, shave=12, batchsize=64):
+        """ipt"""
+        h, w = x.shape[-2:]
+        padsize = int(self.args.patch_size)
+        shave = int(self.args.patch_size / 4)
+        scale = self.args.scale[self.scale_idx]
+
+        h_cut = (h - padsize) % (padsize - shave)
+        w_cut = (w - padsize) % (padsize - shave)
+
+        unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
+        x_unfold = unf_1(x)
+        x_unfold = self.transpose(x_unfold, (1, 0, 2))  # transpose(0, 2)
+
+        x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
+        y_hw_cut = self.construct(x_hw_cut)
+
+        x_h_cut = x[:, :, (h - padsize):, :]
+        x_w_cut = x[:, :, :, (w - padsize):]
+        y_h_cut = self.cut_h_new(x_h_cut, h, w, h_cut,
+                                 w_cut, padsize, shave, scale, batchsize)
+        y_w_cut = self.cut_w_new(x_w_cut, h, w, h_cut,
+                                 w_cut, padsize, shave, scale, batchsize)
+
+        x_h_top = x[:, :, :padsize, :]
+        x_w_top = x[:, :, :, :padsize]
+        y_h_top = self.cut_h_new(x_h_top, h, w, h_cut,
+                                 w_cut, padsize, shave, scale, batchsize)
+        y_w_top = self.cut_w_new(x_w_top, h, w, h_cut,
+                                 w_cut, padsize, shave, scale, batchsize)
+        x_unfold = self.reshape(
+            x_unfold, (x_unfold.shape[0], -1, padsize, padsize))
+        x_range = x_unfold.shape[0] // batchsize + \
+            (x_unfold.shape[0] % batchsize != 0)
+
+        cc_0 = P.Concat(axis=0)
+        for i in range(x_range):
+            if i == 0:
+                y_unfold = self.construct(
+                    x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
+            else:
+                y_unfold = cc_0((y_unfold, self.construct(
+                    x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
+        y_unf_shape_0 = y_unfold.shape[0]
+        fold_1 = \
+            _stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale),
+                          stride=padsize * scale - shave * scale)
+        y = fold_1(self.transpose(self.reshape(
+            y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
+        cc_2 = P.Concat(axis=2)
+        cc_3 = P.Concat(axis=3)
+        y = cc_2((y_h_top, y[:, :, padsize * scale:, :]))
+        y = cc_3((y_w_top, y[:, :, :, padsize * scale:]))
+        y_unfold = y_unfold[:, :, int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale),
+                            int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
+        fold_2 = _stride_fold_(padsize * scale - shave * scale,
+                               output_shape=((h - h_cut - shave) *
+                                             scale, (w - w_cut - shave) * scale),
+                               stride=padsize * scale - shave * scale)
+        y_inter = fold_2(self.transpose(self.reshape(
+            y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
+        y = cc_3((cc_3((y[:, :, :, :int(shave / 2 * scale)], cc_2((cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)), y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])))), y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))  #pylint: disable=line-too-long
+        y = cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
+                  y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
+        y_w_cat = cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
+                        y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
+        y = cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
+                  y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
+        return y
+
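+    # cut_h_new / cut_w_new apply the same unfold -> batched forward -> fold
+    # pipeline to a horizontal (resp. vertical) strip of the image; they return
+    # the upscaled strips that forward_chop_new concatenates onto the borders.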
+    def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
+        """ipt"""
+        unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
+        x_h_cut_unfold = unf_1(x_h_cut)
+        x_h_cut_unfold = self.transpose(x_h_cut_unfold, (1, 0, 2))
+
+        x_h_cut_unfold = self.reshape(
+            x_h_cut_unfold, (x_h_cut_unfold.shape[0], -1, padsize, padsize))
+        x_range = x_h_cut_unfold.shape[0] // batchsize + \
+            (x_h_cut_unfold.shape[0] % batchsize != 0)
+        cc_0 = P.Concat(axis=0)
+        for i in range(x_range):
+            if i == 0:
+                y_h_cut_unfold = self.construct(
+                    x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
+            else:
+                y_h_cut_unfold = \
+                    cc_0((y_h_cut_unfold, self.construct(
+                        x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
+        y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
+        fold_1 = \
+            _stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale),
+                          stride=padsize * scale - shave * scale)
+        y_h_cut = fold_1(self.transpose(self.reshape(
+            y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
+        y_h_cut_unfold = y_h_cut_unfold[:, :, :, int(
+            shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
+        fold_2 = _stride_fold_((padsize * scale, padsize * scale - shave * scale),
+                               output_shape=(padsize * scale,
+                                             (w - w_cut - shave) * scale),
+                               stride=padsize * scale - shave * scale)
+        y_h_cut_inter = fold_2(self.transpose(self.reshape(
+            y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
+        cc_3 = P.Concat(axis=3)
+        y_h_cut = cc_3((cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)],
+                              y_h_cut_inter)), y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
+        return y_h_cut
+
+    def cut_w_new(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
+        """ipt"""
+        unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
+        x_w_cut_unfold = unf_1(x_w_cut)
+        x_w_cut_unfold = self.transpose(x_w_cut_unfold, (1, 0, 2))
+
+        x_w_cut_unfold = self.reshape(
+            x_w_cut_unfold, (x_w_cut_unfold.shape[0], -1, padsize, padsize))
+        x_range = x_w_cut_unfold.shape[0] // batchsize + \
+            (x_w_cut_unfold.shape[0] % batchsize != 0)
+        cc_0 = P.Concat(axis=0)
+        for i in range(x_range):
+            if i == 0:
+                y_w_cut_unfold = self.construct(
+                    x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
+            else:
+                y_w_cut_unfold = cc_0((y_w_cut_unfold,
+                                       self.construct(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
+        y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
+        fold_1 = _stride_fold_(padsize * scale,
+                               output_shape=((h - h_cut) * scale,
+                                             padsize * scale),
+                               stride=padsize * scale - shave * scale)
+        y_w_cut = fold_1(self.transpose(self.reshape(
+            y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
+        y_w_cut_unfold = y_w_cut_unfold[:, :, int(
+            shave / 2 * scale):padsize * scale - int(shave / 2 * scale), :]
+        fold_2 = _stride_fold_((padsize * scale - shave * scale, padsize * scale),
+                               output_shape=((h - h_cut - shave)
+                                             * scale, padsize * scale),
+                               stride=padsize * scale - shave * scale)
+        y_w_cut_inter = fold_2(self.transpose(self.reshape(
+            y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
+        cc_2 = P.Concat(axis=2)
+        y_w_cut = cc_2((cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :],
+                              y_w_cut_inter)), y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
+        return y_w_cut
+
+
+class _stride_unfold_(nn.Cell):
+    """ipt"""
+
+    def __init__(
+            self, kernel_size, stride=-1):
+
+        super(_stride_unfold_, self).__init__()
+        if stride == -1:
+            self.stride = kernel_size
+        else:
+            self.stride = stride
+        self.kernel_size = kernel_size
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+        self.unfold = _unfold_(kernel_size)
+
+    def construct(self, x):
+        """ipt"""
+        N, C, H, W = x.shape
+        leftup_idx_x = []
+        leftup_idx_y = []
+        nh = int((H - self.kernel_size) / self.stride + 1)
+        nw = int((W - self.kernel_size) / self.stride + 1)
+        for i in range(nh):
+            leftup_idx_x.append(i * self.stride)
+        for i in range(nw):
+            leftup_idx_y.append(i * self.stride)
+        NumBlock_x = len(leftup_idx_x)
+        NumBlock_y = len(leftup_idx_y)
+        zeroslike = P.ZerosLike()
+        cc_2 = P.Concat(axis=2)
+        cc_3 = P.Concat(axis=3)
+        unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
+                           NumBlock_y * self.kernel_size), mstype.float32)
+        N, C, H, W = unf_x.shape
+        for i in range(NumBlock_x):
+            for j in range(NumBlock_y):
+                unf_i = i * self.kernel_size
+                unf_j = j * self.kernel_size
+                org_i = leftup_idx_x[i]
+                org_j = leftup_idx_y[j]
+                fills = x[:, :, org_i:org_i + self.kernel_size,
+                          org_j:org_j + self.kernel_size]
+                unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
+                                     cc_2(
+                                         (cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)),
+                                          zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
+                                                          unf_j:unf_j + self.kernel_size]))))),
+                               zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
+        y = self.unfold(unf_x)
+        return y
+
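`_stride_unfold_` extracts overlapping crops by first copying each crop into a large zero-initialized buffer and then applying the non-overlapping `_unfold_`. A shape sanity check (a sketch, assuming MindSpore in PYNATIVE mode):

```python
import numpy as np
from mindspore import Tensor

x = Tensor(np.random.rand(1, 3, 48, 48).astype(np.float32))
unf = _stride_unfold_(kernel_size=24, stride=18)  # 25% overlap, as in IPT
patches = unf(x)
print(patches.shape)  # (1, 4, 1728): a 2x2 grid of 3*24*24 crops
```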
+class _stride_fold_(nn.Cell):
+    """ipt"""
+
+    def __init__(
+            self, kernel_size, output_shape=(-1, -1), stride=-1):
+
+        super(_stride_fold_, self).__init__()
+
+        if isinstance(kernel_size, (list, tuple)):
+            self.kernel_size = kernel_size
+        else:
+            self.kernel_size = [kernel_size, kernel_size]
+
+        if stride == -1:
+            self.stride = self.kernel_size[0]  # fix: use the normalized list; an int kernel_size crashed here
+        else:
+            self.stride = stride
+
+        self.output_shape = output_shape
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+        self.fold = _fold_(kernel_size)
+
+    def construct(self, x):
+        """ipt"""
+        cc_2 = P.Concat(axis=2)
+        cc_3 = P.Concat(axis=3)
+        zeroslike = P.ZerosLike()
+        if self.output_shape[0] == -1:
+            large_x = self.fold(x)
+            N, C, H, _ = large_x.shape
+            leftup_idx = []
+            for i in range(0, H, self.kernel_size[0]):
+                leftup_idx.append(i)
+            NumBlock = len(leftup_idx)
+            fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
+                                (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
+
+            for i in range(NumBlock):
+                for j in range(NumBlock):
+                    fold_i = i * self.stride
+                    fold_j = j * self.stride
+                    org_i = leftup_idx[i]
+                    org_j = leftup_idx[j]
+                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
+                                    org_j:org_j + self.kernel_size[1]]
+                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))  #pylint: disable=line-too-long
+            y = fold_x
+        else:
+            NumBlock_x = int(
+                (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
+            NumBlock_y = int(
+                (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
+            large_shape = [NumBlock_x * self.kernel_size[0],
+                           NumBlock_y * self.kernel_size[1]]
+            self.fold = _fold_(self.kernel_size, large_shape)
+            large_x = self.fold(x)
+            N, C, H, _ = large_x.shape
+            leftup_idx_x = []
+            leftup_idx_y = []
+            for i in range(NumBlock_x):
+                leftup_idx_x.append(i * self.kernel_size[0])
+            for i in range(NumBlock_y):
+                leftup_idx_y.append(i * self.kernel_size[1])
+            fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
+                                (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
+            for i in range(NumBlock_x):
+                for j in range(NumBlock_y):
+                    fold_i = i * self.stride
+                    fold_j = j * self.stride
+                    org_i = leftup_idx_x[i]
+                    org_j = leftup_idx_y[j]
+                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
+                                    org_j:org_j + self.kernel_size[1]]
+                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))  #pylint: disable=line-too-long
+            y = fold_x
+        return y
+
+
+class _unfold_(nn.Cell):
+    """ipt"""
+
+    def __init__(
+            self, kernel_size, stride=-1):
+
+        super(_unfold_, self).__init__()
+        if stride == -1:
+            self.stride = kernel_size
+        else:
+            self.stride = stride  # fix: a caller-supplied stride used to be silently dropped
+        self.kernel_size = kernel_size
+
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+
+    def construct(self, x):
+        """ipt"""
+        N, C, H, W = x.shape
+        numH = int(H / self.kernel_size)
+        numW = int(W / self.kernel_size)
+        if numH * self.kernel_size != H or numW * self.kernel_size != W:
+            # fix: crop both spatial axes (the stray ':,' made this a 5-d index)
+            x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
+            H, W = numH * self.kernel_size, numW * self.kernel_size
+        output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
+
+        output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
+
+        output_img = self.reshape(output_img, (N, C, int(
+            numH * numW), self.kernel_size, self.kernel_size))
+
+        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
+
+        output_img = self.reshape(output_img, (N, int(numH * numW), -1))
+        return output_img
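`_unfold_` and `_fold_` are non-overlapping im2col/col2im counterparts, so for sizes divisible by the kernel they invert each other exactly. A round-trip check (a sketch, assuming MindSpore):

```python
import numpy as np
from mindspore import Tensor

x = np.random.rand(1, 3, 48, 48).astype(np.float32)
patches = _unfold_(24)(Tensor(x))         # (1, 4, 1728) patch matrix
restored = _fold_(24)(patches).asnumpy()  # back to (1, 3, 48, 48)
assert np.allclose(restored, x)
```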
+
+
+class _fold_(nn.Cell):
+    """ipt"""
+
+    def __init__(
+            self, kernel_size, output_shape=(-1, -1), stride=-1):
+
+        super(_fold_, self).__init__()
+
+        if isinstance(kernel_size, (list, tuple)):
+            self.kernel_size = kernel_size
+        else:
+            self.kernel_size = [kernel_size, kernel_size]
+
+        if stride == -1:
+            self.stride = self.kernel_size[0]
+        else:
+            self.stride = stride  # fix: a caller-supplied stride used to be silently dropped
+        self.output_shape = output_shape
+
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+
+    def construct(self, x):
+        """ipt"""
+        N, C, L = x.shape
+        org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
+        if self.output_shape[0] == -1:
+            numH = int(np.sqrt(C))
+            numW = int(np.sqrt(C))
+            org_H = int(numH * self.kernel_size[0])
+            org_W = org_H
+        else:
+            org_H = int(self.output_shape[0])
+            org_W = int(self.output_shape[1])
+            numH = int(org_H / self.kernel_size[0])
+            numW = int(org_W / self.kernel_size[1])
+
+        output_img = self.reshape(
+            x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
+
+        output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
+        output_img = self.reshape(
+            output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
+
+        output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
+
+        output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
+        return output_img
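End-to-end, the network can be smoke-tested as below (a sketch, assuming MindSpore in PYNATIVE mode; the `args` namespace is hypothetical and only carries the fields read above, with illustrative values rather than the repo's defaults):

```python
import numpy as np
from types import SimpleNamespace
from mindspore import Tensor

# Hypothetical stand-in for the repo's argparse namespace.
args = SimpleNamespace(n_feats=64, n_colors=3, rgb_range=255, scale=[4],
                       patch_size=48, patch_dim=3, num_heads=12, num_layers=12,
                       num_queries=1, dropout_rate=0, no_mlp=False,
                       pos_every=False, no_pos=False)
net = IPT(args)
net.set_train(False)
sr = net.infrc(Tensor(np.random.rand(1, 3, 96, 96).astype(np.float32)))
print(sr.shape)  # (1, 3, 384, 384) for x4 super-resolution
```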
diff --git a/model_zoo/research/cv/IPT/src/metrics.py b/model_zoo/research/cv/IPT/src/metrics.py
new file mode 100755
index 00000000000..3fe8b70e287
--- /dev/null
+++ b/model_zoo/research/cv/IPT/src/metrics.py
@@ -0,0 +1,56 @@
+'''metrics'''
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import math
+import numpy as np
+
+
+def quantize(img, rgb_range):
+    '''metrics'''
+    pixel_range = 255 / rgb_range
+    img = np.multiply(img, pixel_range)
+    img = np.clip(img, 0, 255)
+    img = np.round(img) / pixel_range
+    return img
+
+
+def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
+    '''metrics'''
+    hr = np.float32(hr)
+    sr = np.float32(sr)
+    diff = (sr - hr) / rgb_range
+    if y_only:
+        # fix: collapse RGB to the Y channel only when requested
+        gray_coeffs = np.array([65.738, 129.057, 25.064]
+                               ).reshape((1, 3, 1, 1)) / 256
+        diff = np.multiply(diff, gray_coeffs).sum(1)
+    if hr.size == 1:
+        return 0
+    if scale == 1:
+        valid = diff
+    else:
+        valid = diff[..., scale:-scale, scale:-scale]  # shave a scale-sized border
+    mse = np.mean(pow(valid, 2))
+    return -10 * math.log10(mse)
+
+
+def rgb2ycbcr(img, y_only=True):
+    '''metrics'''
+    img = img.astype(np.float32)  # fix: astype returns a copy; the result was being discarded
+    if y_only:
+        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+        return rlt
diff --git a/model_zoo/research/cv/IPT/src/template.py b/model_zoo/research/cv/IPT/src/template.py
new file mode 100755
index 00000000000..3314d6d94ae
--- /dev/null
+++ b/model_zoo/research/cv/IPT/src/template.py
@@ -0,0 +1,67 @@
+'''temp'''
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+def set_template(args):
+    '''temp'''
+    if args.template.find('jpeg') >= 0:
+        args.data_train = 'DIV2K_jpeg'
+        args.data_test = 'DIV2K_jpeg'
+        args.epochs = 200
+        args.decay = '100'
+
+    if args.template.find('EDSR_paper') >= 0:
+        args.model = 'EDSR'
+        args.n_resblocks = 32
+        args.n_feats = 256
+        args.res_scale = 0.1
+
+    if args.template.find('MDSR') >= 0:
+        args.model = 'MDSR'
+        args.patch_size = 48
+        args.epochs = 650
+
+    if args.template.find('DDBPN') >= 0:
+        args.model = 'DDBPN'
+        args.patch_size = 128
+        args.scale = '4'
+
+        args.data_test = 'Set5'
+
+        args.batch_size = 20
+        args.epochs = 1000
+        args.decay = '500'
+        args.gamma = 0.1
+        args.weight_decay = 1e-4
+
+        args.loss = '1*MSE'
+
+    if args.template.find('GAN') >= 0:
+        args.epochs = 200
+        args.lr = 5e-5
+        args.decay = '150'
+
+    if args.template.find('RCAN') >= 0:
+        args.model = 'RCAN'
+        args.n_resgroups = 10
+        args.n_resblocks = 20
+        args.n_feats = 64
+        args.chop = True
+
+    if args.template.find('VDSR') >= 0:
+        args.model = 'VDSR'
+        args.n_resblocks = 20
+        args.n_feats = 64
+        args.patch_size = 41
+        args.lr = 1e-1