forked from mindspore-Ecosystem/mindspore
add IPT net
This commit is contained in:
parent
e0ea2767b7
commit
9022e552fd
|
@ -1,6 +1,6 @@
|
|||
"""eval script"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -13,48 +13,82 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from src import ipt
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.args import args
|
||||
import src.ipt_model as ipt
|
||||
from src.data.srdata import SRData
|
||||
from src.metrics import calc_psnr, quantize
|
||||
|
||||
from mindspore import context
|
||||
import mindspore.dataset as de
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
|
||||
context.set_context(max_call_depth=10000)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="ASCEND", device_id=0)
|
||||
def sub_mean(x):
|
||||
red_channel_mean = 0.4488 * 255
|
||||
green_channel_mean = 0.4371 * 255
|
||||
blue_channel_mean = 0.4040 * 255
|
||||
x[:, 0, :, :] -= red_channel_mean
|
||||
x[:, 1, :, :] -= green_channel_mean
|
||||
x[:, 2, :, :] -= blue_channel_mean
|
||||
return x
|
||||
|
||||
def add_mean(x):
|
||||
red_channel_mean = 0.4488 * 255
|
||||
green_channel_mean = 0.4371 * 255
|
||||
blue_channel_mean = 0.4040 * 255
|
||||
x[:, 0, :, :] += red_channel_mean
|
||||
x[:, 1, :, :] += green_channel_mean
|
||||
x[:, 2, :, :] += blue_channel_mean
|
||||
return x
|
||||
|
||||
def main():
|
||||
def eval_net():
|
||||
"""eval"""
|
||||
args.batch_size = 128
|
||||
args.decay = 70
|
||||
args.patch_size = 48
|
||||
args.num_queries = 6
|
||||
args.model = 'vtip'
|
||||
args.num_layers = 4
|
||||
|
||||
if args.epochs == 0:
|
||||
args.epochs = 1e8
|
||||
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
||||
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
|
||||
train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False)
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR', "idx", "filename"], shuffle=False)
|
||||
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
|
||||
train_loader = train_de_dataset.create_dict_iterator()
|
||||
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
|
||||
|
||||
net_m = ipt.IPT(args)
|
||||
print('load mindspore net successfully.')
|
||||
if args.pth_path:
|
||||
param_dict = load_checkpoint(args.pth_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
net_m.set_train(False)
|
||||
idx = Tensor(np.ones(args.task_id), mstype.int32)
|
||||
inference = ipt.IPT_post(net_m, args)
|
||||
print('load mindspore net successfully.')
|
||||
num_imgs = train_de_dataset.get_dataset_size()
|
||||
psnrs = np.zeros((num_imgs, 1))
|
||||
inference = ipt.IPT_post(net_m, args)
|
||||
for batch_idx, imgs in enumerate(train_loader):
|
||||
lr = imgs['LR']
|
||||
hr = imgs['HR']
|
||||
hr_np = np.float32(hr.asnumpy())
|
||||
pred = inference.forward(lr)
|
||||
pred_np = np.float32(pred.asnumpy())
|
||||
lr = sub_mean(lr)
|
||||
lr = Tensor(lr, mstype.float32)
|
||||
pred = inference.forward(lr, idx)
|
||||
pred_np = add_mean(pred.asnumpy())
|
||||
pred_np = quantize(pred_np, 255)
|
||||
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
|
||||
psnr = calc_psnr(pred_np, hr, 4, 255.0)
|
||||
print("current psnr: ", psnr)
|
||||
psnrs[batch_idx, 0] = psnr
|
||||
if args.denoise:
|
||||
print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
|
||||
|
@ -63,7 +97,6 @@ def main():
|
|||
else:
|
||||
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Start main function!")
|
||||
main()
|
||||
print("Start eval function!")
|
||||
eval_net()
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.vitm import ViT
|
||||
|
||||
|
||||
def IPT(*args, **kwargs):
|
||||
return ViT(*args, **kwargs)
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'IPT':
|
||||
return IPT(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -45,9 +45,9 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on
|
|||
|
||||
## Requirements
|
||||
|
||||
### Hardware (GPU)
|
||||
### Hardware (Ascend)
|
||||
|
||||
> Prepare hardware environment with GPU.
|
||||
> Prepare hardware environment with Ascend.
|
||||
|
||||
### Framework
|
||||
|
||||
|
@ -67,34 +67,73 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on
|
|||
```bash
|
||||
IPT
|
||||
├── eval.py # inference entry
|
||||
├── train.py # pre-training entry
|
||||
├── train_finetune.py # fine-tuning entry
|
||||
├── image
|
||||
│ └── ipt.png # the illustration of IPT network
|
||||
├── model
|
||||
│ ├── IPT_denoise30.ckpt # denoise model weights for noise level 30
|
||||
│ ├── IPT_denoise50.ckpt # denoise model weights for noise level 50
|
||||
│ ├── IPT_derain.ckpt # derain model weights
|
||||
│ ├── IPT_sr2.ckpt # X2 super-resolution model weights
|
||||
│ ├── IPT_sr3.ckpt # X3 super-resolution model weights
|
||||
│ └── IPT_sr4.ckpt # X4 super-resolution model weights
|
||||
├── readme.md # Readme
|
||||
├── scripts
|
||||
│ └── run_eval.sh # inference script for all tasks
|
||||
│ ├── run_eval.sh # inference script for all tasks
|
||||
│ ├── run_distributed.sh # pre-training script for all tasks
|
||||
│ └── run_finetune_distributed.sh # fine-tuning script for all tasks
|
||||
└── src
|
||||
├── args.py # options/hyper-parameters of IPT
|
||||
├── data
|
||||
│ ├── common.py # common dataset
|
||||
│ ├── __init__.py # Class data init function
|
||||
│ └── srdata.py # flow of loading sr data
|
||||
├── foldunfold_stride.py # function of fold and unfold operations for images
|
||||
│ ├── bicubic.py # scripts for data pre-processing
|
||||
│ ├── div2k.py # DIV2K dataset
|
||||
│ ├── imagenet.py # Imagenet data for pre-training
|
||||
│ └── srdata.py # All dataset
|
||||
├── metrics.py # PSNR calculator
|
||||
├── template.py # setting of model selection
|
||||
└── vitm.py # IPT network
|
||||
├── utils.py # training scripts
|
||||
├── loss.py # contrastive_loss
|
||||
└── ipt_model.py # IPT network
|
||||
```
|
||||
|
||||
### Script Parameter
|
||||
|
||||
> For details about hyperparameters, see src/args.py.
|
||||
|
||||
## Training Process
|
||||
|
||||
### For pre-training
|
||||
|
||||
```bash
|
||||
python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save $SAVE_PATH
|
||||
```
|
||||
|
||||
> Or one can run following script for all tasks.
|
||||
|
||||
```bash
|
||||
sh scripts/run_distributed.sh RANK_TABLE_FILE DATA_PATH
|
||||
```
|
||||
|
||||
### For fine-tuning
|
||||
|
||||
> For SR tasks:
|
||||
|
||||
```bash
|
||||
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 50
|
||||
```
|
||||
|
||||
> For Denoising tasks:
|
||||
|
||||
```bash
|
||||
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --denoise --sigma $Noise --epochs 50
|
||||
```
|
||||
|
||||
> For deraining tasks:
|
||||
|
||||
```bash
|
||||
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --derain --epochs 50
|
||||
```
|
||||
|
||||
> Or one can run following script for all tasks.
|
||||
|
||||
```bash
|
||||
sh scripts/run_finetune_distributed.sh RANK_TABLE_FILE DATA_PATH MODEL TASK_ID
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Evaluation Process
|
||||
|
@ -103,13 +142,13 @@ IPT
|
|||
> For SR x4:
|
||||
|
||||
```bash
|
||||
python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt
|
||||
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE
|
||||
```
|
||||
|
||||
> Or one can run following script for all tasks.
|
||||
|
||||
```bash
|
||||
sh scripts/run_eval.sh
|
||||
sh scripts/run_eval.sh DATA_PATH DATA_TEST MODEL TASK_ID
|
||||
```
|
||||
|
||||
### Evaluation Result
|
||||
|
@ -117,7 +156,7 @@ sh scripts/run_eval.sh
|
|||
The result are evaluated by the value of PSNR (Peak Signal-to-Noise Ratio), and the format is as following.
|
||||
|
||||
```bash
|
||||
result: {"Mean psnr of Se5 x4 is 32.68"}
|
||||
result: {"Mean psnr of Set5 x4 is 32.68"}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
RANK_TABLE_FILE=$(realpath $1)
|
||||
export RANK_TABLE_FILE
|
||||
export DATA_PATH=$2
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env$i.log
|
||||
python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save experiments/ckpt_new_init/ > log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -14,18 +14,49 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
DATA_SET=$3
|
||||
PATH_CHECKPOINT=$4
|
||||
ulimit -u unlimited
|
||||
export DATA_PATH=$1
|
||||
export DATA_TEST=$2
|
||||
export MODEL=$3
|
||||
export TASK_ID=$4
|
||||
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
if [[ $TASK_ID -lt 3 ]]; then
|
||||
mkdir ./run_eval$TASK_ID
|
||||
cp -r ../src ./run_eval$TASK_ID
|
||||
cp ../*.py ./run_eval$TASK_ID
|
||||
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
|
||||
cd ./run_eval$TASK_ID ||exit
|
||||
env > env$TASK_ID.log
|
||||
SCALE=$[$TASK_ID+2]
|
||||
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE > log 2>&1 &
|
||||
fi
|
||||
|
||||
##denoise
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
if [[ $TASK_ID -eq 3 ]]; then
|
||||
mkdir ./run_eval$TASK_ID
|
||||
cp -r ../src ./run_eval$TASK_ID
|
||||
cp ../*.py ./run_eval$TASK_ID
|
||||
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
|
||||
cd ./run_eval$TASK_ID ||exit
|
||||
env > env$TASK_ID.log
|
||||
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --derain > log 2>&1 &
|
||||
fi
|
||||
|
||||
##derain
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
if [[ $TASK_ID -eq 4 ]]; then
|
||||
mkdir ./run_eval$TASK_ID
|
||||
cp -r ../src ./run_eval$TASK_ID
|
||||
cp ../*.py ./run_eval$TASK_ID
|
||||
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
|
||||
cd ./run_eval$TASK_ID ||exit
|
||||
env > env$TASK_ID.log
|
||||
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 30 > log 2>&1 &
|
||||
fi
|
||||
|
||||
if [[ $TASK_ID -eq 5 ]]; then
|
||||
mkdir ./run_eval$TASK_ID
|
||||
cp -r ../src ./run_eval$TASK_ID
|
||||
cp ../*.py ./run_eval$TASK_ID
|
||||
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
|
||||
cd ./run_eval$TASK_ID ||exit
|
||||
env > env$TASK_ID.log
|
||||
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 50 > log 2>&1 &
|
||||
fi
|
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
RANK_TABLE_FILE=$(realpath $1)
|
||||
export RANK_TABLE_FILE
|
||||
export DATA_PATH=$2
|
||||
export MODEL=$3
|
||||
export TASK_ID=$4
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env$i.log
|
||||
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 100 > log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -1,4 +1,4 @@
|
|||
'''args'''
|
||||
"""args"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -13,8 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
from src import template
|
||||
|
||||
parser = argparse.ArgumentParser(description='EDSR and MDSR')
|
||||
|
||||
|
@ -24,12 +24,6 @@ parser.add_argument('--template', default='.',
|
|||
help='You can set various templates in option.py')
|
||||
|
||||
# Hardware specifications
|
||||
parser.add_argument('--n_threads', type=int, default=6,
|
||||
help='number of threads for data loading')
|
||||
parser.add_argument('--cpu', action='store_true',
|
||||
help='use cpu only')
|
||||
parser.add_argument('--n_GPUs', type=int, default=1,
|
||||
help='number of GPUs')
|
||||
parser.add_argument('--seed', type=int, default=1,
|
||||
help='random seed')
|
||||
|
||||
|
@ -60,9 +54,8 @@ parser.add_argument('--no_augment', action='store_true',
|
|||
help='do not use data augmentation')
|
||||
|
||||
# Model specifications
|
||||
parser.add_argument('--model', default='vtip',
|
||||
parser.add_argument('--model', default='EDSR',
|
||||
help='model name')
|
||||
|
||||
parser.add_argument('--act', type=str, default='relu',
|
||||
help='activation function')
|
||||
parser.add_argument('--pre_train', type=str, default='',
|
||||
|
@ -139,6 +132,7 @@ parser.add_argument('--gclip', type=float, default=0,
|
|||
help='gradient clipping threshold (0 = no clipping)')
|
||||
|
||||
# Loss specifications
|
||||
parser.add_argument('--con_loss', action='store_true')
|
||||
parser.add_argument('--loss', type=str, default='1*L1',
|
||||
help='loss function configuration')
|
||||
parser.add_argument('--skip_threshold', type=float, default='1e8',
|
||||
|
@ -161,6 +155,7 @@ parser.add_argument('--save_gt', action='store_true',
|
|||
help='save low-resolution and high-resolution images together')
|
||||
|
||||
parser.add_argument('--scalelr', type=int, default=0)
|
||||
|
||||
# cloud
|
||||
parser.add_argument('--moxfile', type=int, default=1)
|
||||
parser.add_argument('--imagenet', type=int, default=0)
|
||||
|
@ -169,10 +164,11 @@ parser.add_argument('--train_url', type=str, help='train_dir')
|
|||
parser.add_argument('--pretrain', type=str, default='')
|
||||
parser.add_argument('--pth_path', type=str, default='')
|
||||
parser.add_argument('--load_query', type=int, default=0)
|
||||
|
||||
# transformer
|
||||
parser.add_argument('--patch_dim', type=int, default=3)
|
||||
parser.add_argument('--num_heads', type=int, default=12)
|
||||
parser.add_argument('--num_layers', type=int, default=12)
|
||||
parser.add_argument('--num_layers', type=int, default=4)
|
||||
parser.add_argument('--dropout_rate', type=float, default=0)
|
||||
parser.add_argument('--no_norm', action='store_true')
|
||||
parser.add_argument('--post_norm', action='store_true')
|
||||
|
@ -192,8 +188,10 @@ parser.add_argument('--sigma', type=float, default=25)
|
|||
parser.add_argument('--derain', action='store_true')
|
||||
parser.add_argument('--finetune', action='store_true')
|
||||
parser.add_argument('--derain_test', type=int, default=10)
|
||||
|
||||
# alltask
|
||||
parser.add_argument('--alltask', action='store_true')
|
||||
parser.add_argument('--task_id', type=int, default=0)
|
||||
|
||||
# dehaze
|
||||
parser.add_argument('--dehaze', action='store_true')
|
||||
|
@ -201,6 +199,7 @@ parser.add_argument('--dehaze_test', type=int, default=100)
|
|||
parser.add_argument('--indoor', action='store_true')
|
||||
parser.add_argument('--outdoor', action='store_true')
|
||||
parser.add_argument('--nochange', action='store_true')
|
||||
|
||||
# deblur
|
||||
parser.add_argument('--deblur', action='store_true')
|
||||
parser.add_argument('--deblur_test', type=int, default=1000)
|
||||
|
@ -210,6 +209,8 @@ parser.add_argument('--init_method', type=str,
|
|||
default=None, help='master address')
|
||||
parser.add_argument('--rank', type=int, default=0,
|
||||
help='Index of current task')
|
||||
parser.add_argument('--group_size', type=int, default=0,
|
||||
help='Index of current task')
|
||||
parser.add_argument('--world_size', type=int, default=1,
|
||||
help='Total number of tasks')
|
||||
parser.add_argument('--gpu', default=None, type=int,
|
||||
|
@ -223,7 +224,6 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
|||
parser.add_argument('--distribute', action='store_true')
|
||||
|
||||
args, unparsed = parser.parse_known_args()
|
||||
template.set_template(args)
|
||||
|
||||
args.scale = [int(x) for x in args.scale.split("+")]
|
||||
args.data_train = args.data_train.split('+')
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
"""data"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
class Data:
|
||||
"""data"""
|
||||
|
||||
def __init__(self, args):
|
||||
self.loader_train = None
|
||||
self.loader_test = []
|
||||
for d in args.data_test:
|
||||
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']:
|
||||
m = import_module('data.benchmark')
|
||||
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
|
||||
else:
|
||||
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
||||
m = import_module('data.' + module_name.lower())
|
||||
testset = getattr(m, module_name)(args, train=False, name=d)
|
||||
|
||||
self.loader_test.append(
|
||||
testset
|
||||
)
|
|
@ -0,0 +1,132 @@
|
|||
"""bicubic"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
class bicubic:
|
||||
"""bicubic"""
|
||||
def __init__(self, seed=0):
|
||||
self.seed = seed
|
||||
self.rand_fn = np.random.RandomState(self.seed)
|
||||
|
||||
def cubic(self, x):
|
||||
absx2 = np.abs(x) * np.abs(x)
|
||||
absx3 = np.abs(x) * np.abs(x) * np.abs(x)
|
||||
|
||||
condition1 = (np.abs(x) <= 1).astype(np.float32)
|
||||
condition2 = ((np.abs(x) > 1) & (np.abs(x) <= 2)).astype(np.float32)
|
||||
|
||||
f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * np.abs(x) + 2) * condition2
|
||||
return f
|
||||
|
||||
def contribute(self, in_size, out_size, scale):
|
||||
"""bicubic"""
|
||||
kernel_width = 4
|
||||
if scale < 1:
|
||||
kernel_width = 4 / scale
|
||||
x0 = np.arange(start=1, stop=out_size[0]+1).astype(np.float32)
|
||||
x1 = np.arange(start=1, stop=out_size[1]+1).astype(np.float32)
|
||||
|
||||
u0 = x0 / scale + 0.5 * (1 - 1 / scale)
|
||||
u1 = x1 / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
left0 = np.floor(u0 - kernel_width / 2)
|
||||
left1 = np.floor(u1 - kernel_width / 2)
|
||||
|
||||
width = np.ceil(kernel_width) + 2
|
||||
indice0 = np.expand_dims(left0, axis=1) + \
|
||||
np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)
|
||||
indice1 = np.expand_dims(left1, axis=1) + \
|
||||
np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)
|
||||
|
||||
mid0 = np.expand_dims(u0, axis=1) - np.expand_dims(indice0, axis=0)
|
||||
mid1 = np.expand_dims(u1, axis=1) - np.expand_dims(indice1, axis=0)
|
||||
|
||||
if scale < 1:
|
||||
weight0 = scale * self.cubic(mid0 * scale)
|
||||
weight1 = scale * self.cubic(mid1 * scale)
|
||||
else:
|
||||
weight0 = self.cubic(mid0)
|
||||
weight1 = self.cubic(mid1)
|
||||
|
||||
weight0 = weight0 / (np.expand_dims(np.sum(weight0, axis=2), 2))
|
||||
weight1 = weight1 / (np.expand_dims(np.sum(weight1, axis=2), 2))
|
||||
|
||||
indice0 = np.expand_dims(np.minimum(np.maximum(1, indice0), in_size[0]), axis=0)
|
||||
indice1 = np.expand_dims(np.minimum(np.maximum(1, indice1), in_size[1]), axis=0)
|
||||
kill0 = np.equal(weight0, 0)[0][0]
|
||||
kill1 = np.equal(weight1, 0)[0][0]
|
||||
|
||||
weight0 = weight0[:, :, kill0 == 0]
|
||||
weight1 = weight1[:, :, kill1 == 0]
|
||||
|
||||
indice0 = indice0[:, :, kill0 == 0]
|
||||
indice1 = indice1[:, :, kill1 == 0]
|
||||
|
||||
return weight0, weight1, indice0, indice1
|
||||
|
||||
def forward(self, hr, rain, lrx2, lrx3, lrx4, filename, batchInfo):
|
||||
"""bicubic"""
|
||||
idx = self.rand_fn.randint(0, 6)
|
||||
if idx < 3:
|
||||
if idx == 0:
|
||||
scale = 1/2
|
||||
hr = lrx2
|
||||
elif idx == 1:
|
||||
scale = 1/3
|
||||
hr = lrx3
|
||||
elif idx == 2:
|
||||
scale = 1/4
|
||||
hr = lrx4
|
||||
hr = np.array(hr)
|
||||
[_, _, h, w] = hr.shape
|
||||
weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale)
|
||||
weight0 = np.asarray(weight0[0], dtype=np.float32)
|
||||
|
||||
indice0 = np.asarray(indice0[0], dtype=np.float32).astype(np.long)
|
||||
weight0 = np.expand_dims(np.expand_dims(np.expand_dims(weight0, axis=0), axis=1), axis=4)
|
||||
out = hr[:, :, (indice0-1), :] * weight0
|
||||
out = np.sum(out, axis=3)
|
||||
A = np.transpose(out, (0, 1, 3, 2))
|
||||
|
||||
weight1 = np.asarray(weight1[0], dtype=np.float32)
|
||||
weight1 = np.expand_dims(np.expand_dims(np.expand_dims(weight1, axis=0), axis=1), axis=4)
|
||||
indice1 = np.asarray(indice1[0], dtype=np.float32).astype(np.long)
|
||||
out = A[:, :, (indice1-1), :] * weight1
|
||||
out = np.round(255 * np.transpose(np.sum(out, axis=3), (0, 1, 3, 2)))/255
|
||||
out = np.clip(np.round(out), 0, 255)
|
||||
lr = list(out)
|
||||
hr = list(hr)
|
||||
else:
|
||||
if idx == 3:
|
||||
hr = np.array(hr)
|
||||
rain = np.array(rain)
|
||||
lr = np.clip((rain + hr), 0, 255)
|
||||
hr = list(hr)
|
||||
lr = list(lr)
|
||||
elif idx == 4:
|
||||
hr = np.array(hr)
|
||||
noise = np.random.randn(*hr.shape) * 30
|
||||
lr = np.clip(noise + hr, 0, 255)
|
||||
hr = list(hr)
|
||||
lr = list(lr)
|
||||
elif idx == 5:
|
||||
hr = np.array(hr)
|
||||
noise = np.random.randn(*hr.shape) * 50
|
||||
lr = np.clip(noise + hr, 0, 255)
|
||||
hr = list(hr)
|
||||
lr = list(lr)
|
||||
return lr, hr, [idx] * len(hr), filename
|
|
@ -1,6 +1,6 @@
|
|||
"""common"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -13,13 +13,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import skimage.color as sc
|
||||
|
||||
|
||||
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
|
||||
def get_patch(*args, patch_size=96, scale=2, input_large=False):
|
||||
"""common"""
|
||||
ih, iw = args[0].shape[:2]
|
||||
|
||||
|
@ -34,25 +32,19 @@ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
|
|||
else:
|
||||
tx, ty = ix, iy
|
||||
|
||||
ret = [
|
||||
args[0][iy:iy + ip, ix:ix + ip, :],
|
||||
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
|
||||
]
|
||||
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def set_channel(*args, n_channels=3):
|
||||
"""common"""
|
||||
|
||||
def _set_channel(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
|
||||
c = img.shape[2]
|
||||
if n_channels == 1 and c == 3:
|
||||
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
|
||||
elif n_channels == 3 and c == 1:
|
||||
if n_channels == 3 and c == 1:
|
||||
img = np.concatenate([img] * n_channels, 2)
|
||||
|
||||
return img[:, :, :n_channels]
|
||||
|
@ -61,14 +53,11 @@ def set_channel(*args, n_channels=3):
|
|||
|
||||
|
||||
def np2Tensor(*args, rgb_range=255):
|
||||
"""common"""
|
||||
|
||||
def _np2Tensor(img):
|
||||
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
||||
tensor = np_transpose.astype(np.float32)
|
||||
tensor = tensor * (rgb_range / 255)
|
||||
return tensor
|
||||
|
||||
input_data = np_transpose.astype(np.float32)
|
||||
output = input_data * (rgb_range / 255)
|
||||
return output
|
||||
return [_np2Tensor(a) for a in args]
|
||||
|
||||
|
||||
|
@ -79,6 +68,7 @@ def augment(*args, hflip=True, rot=True):
|
|||
rot90 = rot and random.random() < 0.5
|
||||
|
||||
def _augment(img):
|
||||
"""common"""
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
|
@ -88,3 +78,18 @@ def augment(*args, hflip=True, rot=True):
|
|||
return img
|
||||
|
||||
return [_augment(a) for a in args]
|
||||
|
||||
|
||||
def search(root, target="JPEG"):
|
||||
"""srdata"""
|
||||
item_list = []
|
||||
items = os.listdir(root)
|
||||
for item in items:
|
||||
path = os.path.join(root, item)
|
||||
if os.path.isdir(path):
|
||||
item_list.extend(search(path, target))
|
||||
elif path.split('/')[-1].startswith(target):
|
||||
item_list.append(path)
|
||||
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
|
||||
item_list.append(path)
|
||||
return item_list
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
"""div2k"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
from src.data.srdata import SRData
|
||||
|
||||
class DIV2K(SRData):
|
||||
"""DIV2K"""
|
||||
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
|
||||
data_range = [r.split('-') for r in args.data_range.split('/')]
|
||||
if train:
|
||||
data_range = data_range[0]
|
||||
else:
|
||||
if args.test_only and len(data_range) == 1:
|
||||
data_range = data_range[0]
|
||||
else:
|
||||
data_range = data_range[1]
|
||||
|
||||
self.begin, self.end = list(map(int, data_range))
|
||||
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
|
||||
|
||||
def _scan(self):
|
||||
names_hr, names_lr = super(DIV2K, self)._scan()
|
||||
names_hr = names_hr[self.begin - 1:self.end]
|
||||
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
|
||||
|
||||
return names_hr, names_lr
|
||||
|
||||
def _set_filesystem(self, dir_data):
|
||||
super(DIV2K, self)._set_filesystem(dir_data)
|
||||
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
||||
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
|
|
@ -0,0 +1,171 @@
|
|||
"""imagent"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import random
|
||||
import io
|
||||
from PIL import Image
|
||||
|
||||
import numpy as np
|
||||
import imageio
|
||||
|
||||
def search(root, target="JPEG"):
|
||||
"""imagent"""
|
||||
item_list = []
|
||||
items = os.listdir(root)
|
||||
for item in items:
|
||||
path = os.path.join(root, item)
|
||||
if os.path.isdir(path):
|
||||
item_list.extend(search(path, target))
|
||||
elif path.split('.')[-1] == target:
|
||||
item_list.append(path)
|
||||
elif path.split('/')[-1].startswith(target):
|
||||
item_list.append(path)
|
||||
return item_list
|
||||
|
||||
|
||||
def get_patch_img(img, patch_size=96, scale=2):
|
||||
"""imagent"""
|
||||
ih, iw = img.shape[:2]
|
||||
tp = scale * patch_size
|
||||
if (iw - tp) > -1 and (ih-tp) > 1:
|
||||
ix = random.randrange(0, iw-tp+1)
|
||||
iy = random.randrange(0, ih-tp+1)
|
||||
hr = img[iy:iy+tp, ix:ix+tp, :3]
|
||||
elif (iw - tp) > -1 and (ih - tp) <= -1:
|
||||
ix = random.randrange(0, iw-tp+1)
|
||||
hr = img[:, ix:ix+tp, :3]
|
||||
pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
|
||||
hr = np.array(pil_img)
|
||||
elif (iw - tp) <= -1 and (ih - tp) > -1:
|
||||
iy = random.randrange(0, ih-tp+1)
|
||||
hr = img[iy:iy+tp, :, :3]
|
||||
pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
|
||||
hr = np.array(pil_img)
|
||||
else:
|
||||
pil_img = Image.fromarray(img).resize((tp, tp), Image.BILINEAR)
|
||||
hr = np.array(pil_img)
|
||||
return hr
|
||||
|
||||
|
||||
class ImgData():
|
||||
"""imagent"""
|
||||
def __init__(self, args, train=True):
|
||||
self.input_large = (args.model == 'VDSR')
|
||||
self.scale = args.scale
|
||||
self.idx_scale = 0
|
||||
self.dataroot = args.dir_data
|
||||
self.img_list = search(os.path.join(self.dataroot, "train"), "JPEG")
|
||||
self.img_list.extend(search(os.path.join(self.dataroot, "val"), "JPEG"))
|
||||
self.img_list = sorted(self.img_list)
|
||||
self.train = train
|
||||
self.args = args
|
||||
self.len = len(self.img_list)
|
||||
print("data length:", len(self.img_list))
|
||||
if self.args.derain:
|
||||
self.derain_dataroot = os.path.join(self.dataroot, "RainTrainL")
|
||||
self.derain_img_list = search(self.derain_dataroot, "rainstreak")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_list)
|
||||
|
||||
def _get_index(self, idx):
|
||||
return idx % len(self.img_list)
|
||||
|
||||
def _load_file(self, idx):
|
||||
idx = self._get_index(idx)
|
||||
f_lr = self.img_list[idx]
|
||||
lr = imageio.imread(f_lr)
|
||||
if len(lr.shape) == 2:
|
||||
lr = np.dstack([lr, lr, lr])
|
||||
return lr, f_lr
|
||||
|
||||
def _np2Tensor(self, img, rgb_range):
|
||||
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
||||
tensor = np_transpose.astype(np.float32)
|
||||
tensor = tensor * (rgb_range / 255)
|
||||
return tensor
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.args.model == 'vtip' and self.train and self.args.alltask:
|
||||
lr, filename = self._load_file(idx % self.len)
|
||||
pair_list = []
|
||||
rain = self._load_rain()
|
||||
rain = np.expand_dims(rain, axis=2)
|
||||
rain = self.get_patch(rain, 1)
|
||||
rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
|
||||
for idx_scale in range(4):
|
||||
self.idx_scale = idx_scale
|
||||
pair = self.get_patch(lr)
|
||||
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
|
||||
pair_list.append(pair_t)
|
||||
return pair_list[3], rain, pair_list[0], pair_list[1], pair_list[2], [self.scale], [filename]
|
||||
if self.args.model == 'vtip' and self.train and len(self.scale) > 1:
|
||||
lr, filename = self._load_file(idx % self.len)
|
||||
pair_list = []
|
||||
for idx_scale in range(3):
|
||||
self.idx_scale = idx_scale
|
||||
pair = self.get_patch(lr)
|
||||
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
|
||||
pair_list.append(pair_t)
|
||||
return pair_list[0], pair_list[1], pair_list[2], filename
|
||||
if self.args.model == 'vtip' and self.args.derain and self.scale[self.idx_scale] == 1:
|
||||
lr, filename = self._load_file(idx % self.len)
|
||||
rain = self._load_rain()
|
||||
rain = np.expand_dims(rain, axis=2)
|
||||
rain = self.get_patch(rain, 1)
|
||||
rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
|
||||
pair = self.get_patch(lr)
|
||||
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
|
||||
return pair_t, rain, filename
|
||||
if self.args.jpeg:
|
||||
hr, filename = self._load_file(idx % self.len)
|
||||
buffer = io.BytesIO()
|
||||
width, height = hr.size
|
||||
patch_size = self.scale[self.idx_scale]*self.args.patch_size
|
||||
if width < patch_size:
|
||||
hr = hr.resize((patch_size, height), Image.ANTIALIAS)
|
||||
width, height = hr.size
|
||||
if height < patch_size:
|
||||
hr = hr.resize((width, patch_size), Image.ANTIALIAS)
|
||||
hr.save(buffer, format='jpeg', quality=25)
|
||||
lr = Image.open(buffer)
|
||||
lr = np.array(lr).astype(np.float32)
|
||||
hr = np.array(hr).astype(np.float32)
|
||||
lr = self.get_patch(lr)
|
||||
hr = self.get_patch(hr)
|
||||
lr = self._np2Tensor(lr, rgb_range=self.args.rgb_range)
|
||||
hr = self._np2Tensor(hr, rgb_range=self.args.rgb_range)
|
||||
return lr, hr, filename
|
||||
lr, filename = self._load_file(idx % self.len)
|
||||
pair = self.get_patch(lr)
|
||||
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
|
||||
return pair_t, filename
|
||||
|
||||
def _load_rain(self):
|
||||
idx = random.randint(0, len(self.derain_img_list) - 1)
|
||||
f_lr = self.derain_img_list[idx]
|
||||
lr = imageio.imread(f_lr)
|
||||
return lr
|
||||
|
||||
def get_patch(self, lr, scale=0):
|
||||
if scale == 0:
|
||||
scale = self.scale[self.idx_scale]
|
||||
lr = get_patch_img(lr, patch_size=self.args.patch_size, scale=scale)
|
||||
return lr
|
||||
|
||||
def set_scale(self, idx_scale):
|
||||
self.idx_scale = idx_scale
|
|
@ -1,6 +1,6 @@
|
|||
"""srdata"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
|
@ -20,43 +21,12 @@ import pickle
|
|||
import numpy as np
|
||||
import imageio
|
||||
from src.data import common
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
def search(root, target="JPEG"):
|
||||
class SRData:
|
||||
"""srdata"""
|
||||
item_list = []
|
||||
items = os.listdir(root)
|
||||
for item in items:
|
||||
path = os.path.join(root, item)
|
||||
if os.path.isdir(path):
|
||||
item_list.extend(search(path, target))
|
||||
elif path.split('/')[-1].startswith(target):
|
||||
item_list.append(path)
|
||||
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
|
||||
item_list.append(path)
|
||||
else:
|
||||
item_list = []
|
||||
return item_list
|
||||
|
||||
|
||||
def search_dehaze(root, target="JPEG"):
|
||||
"""srdata"""
|
||||
item_list = []
|
||||
items = os.listdir(root)
|
||||
for item in items:
|
||||
path = os.path.join(root, item)
|
||||
if os.path.isdir(path):
|
||||
extend_list = search_dehaze(path, target)
|
||||
if extend_list is not None:
|
||||
item_list.extend(extend_list)
|
||||
elif path.split('/')[-2].endswith(target):
|
||||
item_list.append(path)
|
||||
return item_list
|
||||
|
||||
|
||||
class SRData():
|
||||
"""srdata"""
|
||||
|
||||
def __init__(self, args, name='', train=True, benchmark=False):
|
||||
self.args = args
|
||||
self.name = name
|
||||
|
@ -69,37 +39,46 @@ class SRData():
|
|||
self.idx_scale = 0
|
||||
|
||||
if self.args.derain:
|
||||
self.derain_test = os.path.join(args.dir_data, "Rain100L")
|
||||
self.derain_lr_test = search(self.derain_test, "rain")
|
||||
self.derain_hr_test = [path.replace(
|
||||
"rainy/", "no") for path in self.derain_lr_test]
|
||||
if self.train:
|
||||
self.derain_dataroot = os.path.join(args.dir_data, "RainTrainL")
|
||||
self.clear_train = common.search(self.derain_dataroot, "norain")
|
||||
self.rain_train = []
|
||||
for path in self.clear_train:
|
||||
change_path = path.split('/')
|
||||
change_path[-1] = change_path[-1][2:]
|
||||
self.rain_train.append('/'.join(change_path))
|
||||
self.derain_test = os.path.join(args.dir_data, "Rain100L")
|
||||
self.deblur_lr_test = common.search(self.derain_test, "rain")
|
||||
self.deblur_hr_test = [path.replace("rainy/", "no") for path in self.deblur_lr_test]
|
||||
self.derain_hr_test = self.deblur_hr_test
|
||||
else:
|
||||
self.derain_test = os.path.join(args.dir_data, "Rain100L")
|
||||
self.derain_lr_test = common.search(self.derain_test, "rain")
|
||||
self.derain_hr_test = [path.replace("rainy/", "no") for path in self.derain_lr_test]
|
||||
self._set_filesystem(args.dir_data)
|
||||
self._set_img(args)
|
||||
if self.args.derain and self.train:
|
||||
self.images_hr, self.images_lr = self.clear_train, self.rain_train
|
||||
if train:
|
||||
self._repeat(args)
|
||||
|
||||
def _set_img(self, args):
|
||||
"""srdata"""
|
||||
if args.ext.find('img') < 0:
|
||||
path_bin = os.path.join(self.apath, 'bin')
|
||||
os.makedirs(path_bin, exist_ok=True)
|
||||
|
||||
list_hr, list_lr = self._scan()
|
||||
if args.ext.find('img') >= 0 or benchmark:
|
||||
if args.ext.find('img') >= 0 or self.benchmark:
|
||||
self.images_hr, self.images_lr = list_hr, list_lr
|
||||
elif args.ext.find('sep') >= 0:
|
||||
os.makedirs(
|
||||
self.dir_hr.replace(self.apath, path_bin),
|
||||
exist_ok=True
|
||||
)
|
||||
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
|
||||
for s in self.scale:
|
||||
if s == 1:
|
||||
os.makedirs(
|
||||
os.path.join(self.dir_hr),
|
||||
exist_ok=True
|
||||
)
|
||||
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
|
||||
else:
|
||||
os.makedirs(
|
||||
os.path.join(
|
||||
self.dir_lr.replace(self.apath, path_bin),
|
||||
'X{}'.format(s)
|
||||
),
|
||||
exist_ok=True
|
||||
)
|
||||
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
|
||||
|
||||
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
|
||||
for h in list_hr:
|
||||
|
@ -114,23 +93,27 @@ class SRData():
|
|||
self.images_lr[i].append(b)
|
||||
self._check_and_load(args.ext, l, b, verbose=True)
|
||||
|
||||
# Below functions as used to prepare images
|
||||
def _repeat(self, args):
|
||||
"""srdata"""
|
||||
n_patches = args.batch_size * args.test_every
|
||||
n_images = len(args.data_train) * len(self.images_hr)
|
||||
if n_images == 0:
|
||||
self.repeat = 0
|
||||
else:
|
||||
self.repeat = max(n_patches // n_images, 1)
|
||||
|
||||
def _scan(self):
|
||||
"""srdata"""
|
||||
names_hr = sorted(
|
||||
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
|
||||
)
|
||||
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
|
||||
names_lr = [[] for _ in self.scale]
|
||||
for f in names_hr:
|
||||
filename, _ = os.path.splitext(os.path.basename(f))
|
||||
for si, s in enumerate(self.scale):
|
||||
if s != 1:
|
||||
scale = s
|
||||
names_lr[si].append(os.path.join(
|
||||
self.dir_lr, 'X{}/{}x{}{}'.format(
|
||||
s, filename, scale, self.ext[1]
|
||||
)
|
||||
))
|
||||
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
|
||||
.format(s, filename, scale, self.ext[1])))
|
||||
for si, s in enumerate(self.scale):
|
||||
if s == 1:
|
||||
names_lr[si] = names_hr
|
||||
|
@ -150,28 +133,33 @@ class SRData():
|
|||
pickle.dump(imageio.imread(img), _f)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.args.model == 'vtip' and self.args.derain and self.scale[
|
||||
self.idx_scale] == 1 and not self.args.finetune:
|
||||
norain, rain, _ = self._load_rain_test(idx)
|
||||
pair = common.set_channel(
|
||||
*[rain, norain], n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
return pair_t[0], pair_t[1]
|
||||
if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1:
|
||||
hr, _ = self._load_file_hr(idx)
|
||||
if self.args.derain and self.scale[self.idx_scale] == 1:
|
||||
if self.train:
|
||||
lr, hr, filename = self._load_file_deblur(idx)
|
||||
pair = self.get_patch(lr, hr)
|
||||
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
else:
|
||||
norain, rain, filename = self._load_rain_test(idx)
|
||||
pair = common.set_channel(*[rain, norain], n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
return pair_t[0], pair_t[1], [self.idx_scale], [filename]
|
||||
|
||||
if self.args.denoise and self.scale[self.idx_scale] == 1:
|
||||
hr, filename = self._load_file_hr(idx)
|
||||
pair = self.get_patch_hr(hr)
|
||||
pair = common.set_channel(*[pair], n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
noise = np.random.randn(*pair_t[0].shape) * self.args.sigma
|
||||
lr = pair_t[0] + noise
|
||||
lr = np.float32(np.clip(lr, 0, 255))
|
||||
return lr, pair_t[0]
|
||||
lr, hr, _ = self._load_file(idx)
|
||||
return lr, pair_t[0], [self.idx_scale], [filename]
|
||||
lr, hr, filename = self._load_file(idx)
|
||||
pair = self.get_patch(lr, hr)
|
||||
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
||||
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
||||
|
||||
return pair_t[0], pair_t[1]
|
||||
return pair_t[0], pair_t[1], [self.idx_scale], [filename]
|
||||
|
||||
def __len__(self):
|
||||
if self.train:
|
||||
|
@ -182,7 +170,6 @@ class SRData():
|
|||
return len(self.images_hr)
|
||||
|
||||
def _get_index(self, idx):
|
||||
"""srdata"""
|
||||
if self.train:
|
||||
return idx % len(self.images_hr)
|
||||
return idx
|
||||
|
@ -198,22 +185,9 @@ class SRData():
|
|||
elif self.args.ext.find('sep') >= 0:
|
||||
with open(f_hr, 'rb') as _f:
|
||||
hr = pickle.load(_f)
|
||||
|
||||
return hr, filename
|
||||
|
||||
def _load_rain(self, idx, rain_img=False):
|
||||
"""srdata"""
|
||||
idx = random.randint(0, len(self.derain_img_list) - 1)
|
||||
f_lr = self.derain_img_list[idx]
|
||||
if rain_img:
|
||||
norain = imageio.imread(f_lr.replace("rainstreak", "norain"))
|
||||
rain = imageio.imread(f_lr.replace("rainstreak", "rain"))
|
||||
return norain, rain
|
||||
lr = imageio.imread(f_lr)
|
||||
return lr
|
||||
|
||||
def _load_rain_test(self, idx):
|
||||
"""srdata"""
|
||||
f_hr = self.derain_hr_test[idx]
|
||||
f_lr = self.derain_lr_test[idx]
|
||||
filename, _ = os.path.splitext(os.path.basename(f_lr))
|
||||
|
@ -221,14 +195,6 @@ class SRData():
|
|||
rain = imageio.imread(f_lr)
|
||||
return norain, rain, filename
|
||||
|
||||
def _load_denoise(self, idx):
|
||||
"""srdata"""
|
||||
idx = self._get_index(idx)
|
||||
f_lr = self.images_hr[idx]
|
||||
norain = imageio.imread(f_lr)
|
||||
rain = imageio.imread(f_lr.replace("HR", "LR_bicubic"))
|
||||
return norain, rain
|
||||
|
||||
def _load_file(self, idx):
|
||||
"""srdata"""
|
||||
idx = self._get_index(idx)
|
||||
|
@ -251,12 +217,7 @@ class SRData():
|
|||
def get_patch_hr(self, hr):
|
||||
"""srdata"""
|
||||
if self.train:
|
||||
hr = self.get_patch_img_hr(
|
||||
hr,
|
||||
patch_size=self.args.patch_size,
|
||||
scale=1
|
||||
)
|
||||
|
||||
hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
|
||||
return hr
|
||||
|
||||
def get_patch_img_hr(self, img, patch_size=96, scale=2):
|
||||
|
@ -280,9 +241,7 @@ class SRData():
|
|||
lr, hr = common.get_patch(
|
||||
lr, hr,
|
||||
patch_size=self.args.patch_size * scale,
|
||||
scale=scale,
|
||||
multi=(len(self.scale) > 1)
|
||||
)
|
||||
scale=scale)
|
||||
if not self.args.no_augment:
|
||||
lr, hr = common.augment(lr, hr)
|
||||
else:
|
||||
|
@ -292,7 +251,6 @@ class SRData():
|
|||
return lr, hr
|
||||
|
||||
def set_scale(self, idx_scale):
|
||||
"""srdata"""
|
||||
if not self.input_large:
|
||||
self.idx_scale = idx_scale
|
||||
else:
|
||||
|
|
|
@ -1,241 +0,0 @@
|
|||
'''stride'''
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class _stride_unfold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=-1):
|
||||
|
||||
super(_stride_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
self.stride = kernel_size
|
||||
else:
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.unfold = _unfold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
"""stride"""
|
||||
N, C, H, W = x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
nh = int(H / self.stride)
|
||||
nw = int(W / self.stride)
|
||||
for i in range(nh):
|
||||
leftup_idx_x.append(i * self.stride)
|
||||
for i in range(nw):
|
||||
leftup_idx_y.append(i * self.stride)
|
||||
NumBlock_x = len(leftup_idx_x)
|
||||
NumBlock_y = len(leftup_idx_y)
|
||||
zeroslike = P.ZerosLike()
|
||||
cc_2 = P.Concat(axis=2)
|
||||
cc_3 = P.Concat(axis=3)
|
||||
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
|
||||
NumBlock_y * self.kernel_size), mstype.float32)
|
||||
N, C, H, W = unf_x.shape
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
unf_i = i * self.kernel_size
|
||||
unf_j = j * self.kernel_size
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size,
|
||||
org_j:org_j + self.kernel_size]
|
||||
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
|
||||
(zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
|
||||
unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
|
||||
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
|
||||
y = self.unfold(unf_x)
|
||||
return y
|
||||
|
||||
|
||||
class _stride_fold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
output_shape=(-1, -1),
|
||||
stride=-1):
|
||||
|
||||
super(_stride_fold_, self).__init__()
|
||||
if isinstance(kernel_size, (list, tuple)):
|
||||
self.kernel_size = kernel_size
|
||||
else:
|
||||
self.kernel_size = [kernel_size, kernel_size]
|
||||
|
||||
if stride == -1:
|
||||
self.stride = kernel_size[0]
|
||||
else:
|
||||
self.stride = stride
|
||||
|
||||
self.output_shape = output_shape
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.fold = _fold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
'''stride'''
|
||||
if self.output_shape[0] == -1:
|
||||
large_x = self.fold(x)
|
||||
N, C, H, _ = large_x.shape
|
||||
leftup_idx = []
|
||||
for i in range(0, H, self.kernel_size[0]):
|
||||
leftup_idx.append(i)
|
||||
NumBlock = len(leftup_idx)
|
||||
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
|
||||
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
|
||||
|
||||
for i in range(NumBlock):
|
||||
for j in range(NumBlock):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx[i]
|
||||
org_j = leftup_idx[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size[0],
|
||||
org_j:org_j + self.kernel_size[1]]
|
||||
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
|
||||
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
|
||||
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
|
||||
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
|
||||
y = fold_x
|
||||
else:
|
||||
NumBlock_x = int(
|
||||
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
|
||||
NumBlock_y = int(
|
||||
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
|
||||
large_shape = [NumBlock_x * self.kernel_size[0],
|
||||
NumBlock_y * self.kernel_size[1]]
|
||||
self.fold = _fold_(self.kernel_size, large_shape)
|
||||
large_x = self.fold(x)
|
||||
N, C, H, _ = large_x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
for i in range(NumBlock_x):
|
||||
leftup_idx_x.append(i * self.kernel_size[0])
|
||||
for i in range(NumBlock_y):
|
||||
leftup_idx_y.append(i * self.kernel_size[1])
|
||||
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
|
||||
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size[0],
|
||||
org_j:org_j + self.kernel_size[1]]
|
||||
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
|
||||
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
|
||||
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
|
||||
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
|
||||
y = fold_x
|
||||
return y
|
||||
|
||||
|
||||
class _unfold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=-1):
|
||||
|
||||
super(_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
self.stride = kernel_size
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
'''stride'''
|
||||
N, C, H, W = x.shape
|
||||
numH = int(H / self.kernel_size)
|
||||
numW = int(W / self.kernel_size)
|
||||
if numH * self.kernel_size != H or numW * self.kernel_size != W:
|
||||
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
|
||||
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, C, int(
|
||||
numH * numW), self.kernel_size, self.kernel_size))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
|
||||
return output_img
|
||||
|
||||
|
||||
class _fold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
output_shape=(-1, -1),
|
||||
stride=-1):
|
||||
|
||||
super(_fold_, self).__init__()
|
||||
|
||||
if isinstance(kernel_size, (list, tuple)):
|
||||
self.kernel_size = kernel_size
|
||||
else:
|
||||
self.kernel_size = [kernel_size, kernel_size]
|
||||
|
||||
if stride == -1:
|
||||
self.stride = kernel_size[0]
|
||||
self.output_shape = output_shape
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
'''stride'''
|
||||
N, C, L = x.shape
|
||||
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
|
||||
if self.output_shape[0] == -1:
|
||||
numH = int(np.sqrt(C))
|
||||
numW = int(np.sqrt(C))
|
||||
org_H = int(numH * self.kernel_size[0])
|
||||
org_W = org_H
|
||||
else:
|
||||
org_H = int(self.output_shape[0])
|
||||
org_W = int(self.output_shape[1])
|
||||
numH = int(org_H / self.kernel_size[0])
|
||||
numW = int(org_W / self.kernel_size[1])
|
||||
|
||||
output_img = self.reshape(
|
||||
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
|
||||
output_img = self.reshape(
|
||||
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
|
||||
|
||||
output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
|
||||
return output_img
|
341
model_zoo/research/cv/IPT/src/ipt.py → model_zoo/research/cv/IPT/src/ipt_model.py
Executable file → Normal file
341
model_zoo/research/cv/IPT/src/ipt.py → model_zoo/research/cv/IPT/src/ipt_model.py
Executable file → Normal file
|
@ -13,15 +13,30 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import math
|
||||
import copy
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import Tensor, Parameter
|
||||
|
||||
class LayerPreprocess(nn.Cell):
|
||||
"""
|
||||
Preprocess input of each layer
|
||||
"""
|
||||
def __init__(self, in_channels=None):
|
||||
super(LayerPreprocess, self).__init__()
|
||||
self.layernorm = nn.LayerNorm((in_channels,))
|
||||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
output = self.cast(input_tensor, mstype.float32)
|
||||
output = self.layernorm(output)
|
||||
output = self.cast(output, self.get_dtype(input_tensor))
|
||||
return output
|
||||
|
||||
class MultiheadAttention(nn.Cell):
|
||||
"""
|
||||
|
@ -45,7 +60,7 @@ class MultiheadAttention(nn.Cell):
|
|||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
|
||||
tensor. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -64,13 +79,12 @@ class MultiheadAttention(nn.Cell):
|
|||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
do_return_2d_tensor=False,
|
||||
compute_type=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
same_dim=True):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.size_per_head = int(hidden_width / num_attention_heads)
|
||||
self.has_attention_mask = has_attention_mask
|
||||
assert has_attention_mask
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.do_return_2d_tensor = do_return_2d_tensor
|
||||
|
@ -83,11 +97,9 @@ class MultiheadAttention(nn.Cell):
|
|||
self.shape_k_2d = (-1, k_tensor_width)
|
||||
self.shape_v_2d = (-1, v_tensor_width)
|
||||
self.hidden_width = int(hidden_width)
|
||||
# units = num_attention_heads * self.size_per_head
|
||||
if self.same_dim:
|
||||
self.in_proj_layer = \
|
||||
Parameter(Tensor(np.random.rand(hidden_width * 3,
|
||||
q_tensor_width), dtype=compute_type), name="weight")
|
||||
self.in_proj_layer = Parameter(Tensor(np.random.rand(hidden_width * 3,
|
||||
q_tensor_width), dtype=mstype.float32), name="weight")
|
||||
else:
|
||||
self.query_layer = nn.Dense(q_tensor_width,
|
||||
hidden_width,
|
||||
|
@ -132,8 +144,10 @@ class MultiheadAttention(nn.Cell):
|
|||
self.equal = P.Equal()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None):
|
||||
"""Apply multihead attention."""
|
||||
def construct(self, tensor_q, tensor_k, tensor_v):
|
||||
"""
|
||||
Apply multihead attention.
|
||||
"""
|
||||
batch_size, seq_length, _ = self.shape(tensor_q)
|
||||
shape_qkv = (batch_size, -1,
|
||||
self.num_attention_heads, self.size_per_head)
|
||||
|
@ -161,20 +175,14 @@ class MultiheadAttention(nn.Cell):
|
|||
_start = 0
|
||||
_end = self.hidden_width
|
||||
_w = self.in_proj_layer[_start:_end, :]
|
||||
# _b = None
|
||||
query_out = self.matmul_dense(_w, tensor_q_2d)
|
||||
|
||||
_start = self.hidden_width
|
||||
_end = self.hidden_width * 2
|
||||
_w = self.in_proj_layer[_start:_end, :]
|
||||
# _b = None
|
||||
key_out = self.matmul_dense(_w, tensor_k_2d)
|
||||
|
||||
_start = self.hidden_width * 2
|
||||
|
||||
_end = None
|
||||
_w = self.in_proj_layer[_start:]
|
||||
# _b = None
|
||||
value_out = self.matmul_dense(_w, tensor_v_2d)
|
||||
else:
|
||||
query_out = self.query_layer(tensor_q_2d)
|
||||
|
@ -193,8 +201,7 @@ class MultiheadAttention(nn.Cell):
|
|||
|
||||
attention_scores = self.softmax_cast(attention_scores, mstype.float32)
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
attention_probs = self.softmax_cast(
|
||||
attention_probs, self.get_dtype(key_layer))
|
||||
attention_probs = self.softmax_cast(attention_probs, mstype.float16)
|
||||
if self.use_dropout:
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
|
@ -212,11 +219,8 @@ class MultiheadAttention(nn.Cell):
|
|||
|
||||
class TransformerEncoderLayer(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu"):
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, compute_type=mstype.float16):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = MultiheadAttention(q_tensor_width=d_model,
|
||||
k_tensor_width=d_model,
|
||||
v_tensor_width=d_model,
|
||||
|
@ -224,12 +228,12 @@ class TransformerEncoderLayer(nn.Cell):
|
|||
out_tensor_width=d_model,
|
||||
num_attention_heads=nhead,
|
||||
attention_probs_dropout_prob=dropout)
|
||||
self.linear1 = nn.Dense(d_model, dim_feedforward)
|
||||
self.linear1 = nn.Dense(d_model, dim_feedforward).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1. - dropout)
|
||||
self.linear2 = nn.Dense(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm([d_model])
|
||||
self.norm2 = nn.LayerNorm([d_model])
|
||||
self.norm1 = LayerPreprocess(d_model)
|
||||
self.norm2 = LayerPreprocess(d_model)
|
||||
self.dropout1 = nn.Dropout(1. - dropout)
|
||||
self.dropout2 = nn.Dropout(1. - dropout)
|
||||
self.reshape = P.Reshape()
|
||||
|
@ -237,7 +241,6 @@ class TransformerEncoderLayer(nn.Cell):
|
|||
self.activation = P.ReLU()
|
||||
|
||||
def with_pos_embed(self, tensor, pos):
|
||||
"""ipt"""
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def construct(self, src, pos=None):
|
||||
|
@ -258,10 +261,8 @@ class TransformerEncoderLayer(nn.Cell):
|
|||
|
||||
|
||||
class TransformerDecoderLayer(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
activation="relu"):
|
||||
""" ipt"""
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
|
||||
super().__init__()
|
||||
self.self_attn = MultiheadAttention(q_tensor_width=d_model,
|
||||
k_tensor_width=d_model,
|
||||
|
@ -281,9 +282,9 @@ class TransformerDecoderLayer(nn.Cell):
|
|||
self.dropout = nn.Dropout(1. - dropout)
|
||||
self.linear2 = nn.Dense(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm([d_model])
|
||||
self.norm2 = nn.LayerNorm([d_model])
|
||||
self.norm3 = nn.LayerNorm([d_model])
|
||||
self.norm1 = LayerPreprocess(d_model)
|
||||
self.norm2 = LayerPreprocess(d_model)
|
||||
self.norm3 = LayerPreprocess(d_model)
|
||||
self.dropout1 = nn.Dropout(1. - dropout)
|
||||
self.dropout2 = nn.Dropout(1. - dropout)
|
||||
self.dropout3 = nn.Dropout(1. - dropout)
|
||||
|
@ -291,7 +292,6 @@ class TransformerDecoderLayer(nn.Cell):
|
|||
self.activation = P.ReLU()
|
||||
|
||||
def with_pos_embed(self, tensor, pos):
|
||||
"""ipt"""
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def construct(self, tgt, memory, pos=None, query_pos=None):
|
||||
|
@ -306,7 +306,7 @@ class TransformerDecoderLayer(nn.Cell):
|
|||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos),
|
||||
tensor_k=self.with_pos_embed(memory, pos),
|
||||
tensor_v=memory,)
|
||||
tensor_v=memory)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.reshape(tgt2, permute_linear)
|
||||
|
@ -318,47 +318,38 @@ class TransformerDecoderLayer(nn.Cell):
|
|||
|
||||
class TransformerEncoder(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, encoder_layer, num_layers):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def construct(self, src, pos=None):
|
||||
"""ipt"""
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, pos=pos)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, decoder_layer, num_layers):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def construct(self, tgt, memory, pos=None, query_pos=None):
|
||||
"""ipt"""
|
||||
output = tgt
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, memory, pos=pos, query_pos=query_pos)
|
||||
return output
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
"""ipt"""
|
||||
return nn.CellList([copy.deepcopy(module) for i in range(N)])
|
||||
def _get_clones(module, n):
|
||||
return nn.CellList([copy.deepcopy(module) for i in range(n)])
|
||||
|
||||
|
||||
class LearnedPositionalEncoding(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, max_position_embeddings, embedding_dim, seq_length):
|
||||
super(LearnedPositionalEncoding, self).__init__()
|
||||
self.pe = nn.Embedding(
|
||||
|
@ -370,8 +361,7 @@ class LearnedPositionalEncoding(nn.Cell):
|
|||
self.position_ids = self.reshape(
|
||||
self.position_ids, (1, self.seq_length))
|
||||
|
||||
def construct(self, x, position_ids=None):
|
||||
"""ipt"""
|
||||
def construct(self, position_ids=None):
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, : self.seq_length]
|
||||
|
||||
|
@ -381,46 +371,35 @@ class LearnedPositionalEncoding(nn.Cell):
|
|||
|
||||
class VisionTransformer(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_dim,
|
||||
patch_dim,
|
||||
num_channels,
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
num_layers,
|
||||
hidden_dim,
|
||||
num_queries,
|
||||
idx,
|
||||
positional_encoding_type="learned",
|
||||
dropout_rate=0,
|
||||
norm=False,
|
||||
mlp=False,
|
||||
pos_every=False,
|
||||
no_pos=False
|
||||
):
|
||||
def __init__(self,
|
||||
img_dim,
|
||||
patch_dim,
|
||||
num_channels,
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
num_layers,
|
||||
hidden_dim,
|
||||
num_queries,
|
||||
dropout_rate=0,
|
||||
norm=False,
|
||||
mlp=False,
|
||||
pos_every=False,
|
||||
no_pos=False,
|
||||
con_loss=False):
|
||||
super(VisionTransformer, self).__init__()
|
||||
|
||||
assert embedding_dim % num_heads == 0
|
||||
assert img_dim % patch_dim == 0
|
||||
self.norm = norm
|
||||
self.mlp = mlp
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.patch_dim = patch_dim
|
||||
self.num_channels = num_channels
|
||||
|
||||
self.img_dim = img_dim
|
||||
self.pos_every = pos_every
|
||||
self.num_patches = int((img_dim // patch_dim) ** 2)
|
||||
self.seq_length = self.num_patches
|
||||
self.flatten_dim = patch_dim * patch_dim * num_channels
|
||||
|
||||
self.out_dim = patch_dim * patch_dim * num_channels
|
||||
|
||||
self.no_pos = no_pos
|
||||
|
||||
self.unf = _unfold_(patch_dim)
|
||||
self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim))
|
||||
|
||||
|
@ -432,8 +411,7 @@ class VisionTransformer(nn.Cell):
|
|||
nn.Dropout(1. - dropout_rate),
|
||||
nn.ReLU(),
|
||||
nn.Dense(hidden_dim, self.out_dim),
|
||||
nn.Dropout(1. - dropout_rate)
|
||||
)
|
||||
nn.Dropout(1. - dropout_rate))
|
||||
|
||||
self.query_embed = nn.Embedding(
|
||||
num_queries, embedding_dim * self.seq_length)
|
||||
|
@ -449,55 +427,54 @@ class VisionTransformer(nn.Cell):
|
|||
self.tile = P.Tile()
|
||||
self.transpose = P.Transpose()
|
||||
if not self.no_pos:
|
||||
self.position_encoding = LearnedPositionalEncoding(
|
||||
self.seq_length, self.embedding_dim, self.seq_length
|
||||
)
|
||||
self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim, self.seq_length)
|
||||
|
||||
self.dropout_layer1 = nn.Dropout(1. - dropout_rate)
|
||||
self.query_idx = idx
|
||||
self.query_idx_tensor = Tensor(idx, mstype.int32)
|
||||
def construct(self, x):
|
||||
self.con_loss = con_loss
|
||||
|
||||
def construct(self, x, query_idx_tensor):
|
||||
"""ipt"""
|
||||
B, _, _, _ = x.shape
|
||||
x = self.unf(x)
|
||||
B, N, _ = x.shape
|
||||
b, n, _ = x.shape
|
||||
|
||||
if self.mlp is not True:
|
||||
x = self.reshape(x, (B * N, -1))
|
||||
x = self.reshape(x, (b * n, -1))
|
||||
x = self.dropout_layer1(self.linear_encoding(x)) + x
|
||||
x = self.reshape(x, (B, N, -1))
|
||||
x = self.reshape(x, (b, n, -1))
|
||||
query_embed = self.tile(
|
||||
self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)),
|
||||
(B, 1, 1))
|
||||
self.reshape(self.query_embed(query_idx_tensor), (1, self.seq_length, self.embedding_dim)), (b, 1, 1))
|
||||
|
||||
if not self.no_pos:
|
||||
pos = self.position_encoding(x)
|
||||
pos = self.position_encoding()
|
||||
x = self.encoder(x + pos)
|
||||
else:
|
||||
x = self.encoder(x)
|
||||
x = self.decoder(x, x, query_pos=query_embed)
|
||||
|
||||
if self.mlp is not True:
|
||||
x = self.reshape(x, (B * N, -1))
|
||||
x = self.reshape(x, (b * n, -1))
|
||||
x = self.mlp_head(x) + x
|
||||
x = self.reshape(x, (B, N, -1))
|
||||
x = self.reshape(x, (b, n, -1))
|
||||
if self.con_loss:
|
||||
con_x = x
|
||||
x = self.fold(x)
|
||||
return x, con_x
|
||||
x = self.fold(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
|
||||
"""ipt"""
|
||||
return nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size, has_bias=has_bias)
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size, has_bias=has_bias)
|
||||
|
||||
|
||||
class MeanShift(nn.Conv2d):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(
|
||||
self, rgb_range,
|
||||
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
|
||||
def __init__(self,
|
||||
rgb_range,
|
||||
rgb_mean=(0.4488, 0.4371, 0.4040),
|
||||
rgb_std=(1.0, 1.0, 1.0),
|
||||
sign=-1):
|
||||
super(MeanShift, self).__init__(3, 3, kernel_size=1)
|
||||
self.reshape = P.Reshape()
|
||||
self.eye = P.Eye()
|
||||
|
@ -512,10 +489,14 @@ class MeanShift(nn.Conv2d):
|
|||
|
||||
class ResBlock(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(
|
||||
self, conv, n_feats, kernel_size,
|
||||
bias=True, bn=False, act=nn.ReLU(), res_scale=1):
|
||||
def __init__(self,
|
||||
conv,
|
||||
n_feats,
|
||||
kernel_size,
|
||||
bias=True,
|
||||
bn=False,
|
||||
act=nn.ReLU(),
|
||||
res_scale=1):
|
||||
|
||||
super(ResBlock, self).__init__()
|
||||
m = []
|
||||
|
@ -532,35 +513,28 @@ class ResBlock(nn.Cell):
|
|||
self.mul = P.Mul()
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
res = self.mul(self.body(x), self.res_scale)
|
||||
res += x
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _pixelsf_(x, scale):
|
||||
"""ipt"""
|
||||
N, C, iH, iW = x.shape
|
||||
oH = iH * scale
|
||||
oW = iW * scale
|
||||
oC = C // (scale ** 2)
|
||||
|
||||
output = P.Reshape()(x, (N, oC, scale, scale, iH, iW))
|
||||
|
||||
output = P.Transpose()(output, (0, 1, 5, 3, 4, 2))
|
||||
|
||||
output = P.Reshape()(output, (N, oC, oH, oW))
|
||||
|
||||
output = P.Transpose()(output, (0, 1, 3, 2))
|
||||
|
||||
n, c, ih, iw = x.shape
|
||||
oh = ih * scale
|
||||
ow = iw * scale
|
||||
oc = c // (scale ** 2)
|
||||
output = P.Transpose()(x, (0, 2, 1, 3))
|
||||
output = P.Reshape()(output, (n, ih, oc*scale, scale, iw))
|
||||
output = P.Transpose()(output, (0, 1, 2, 4, 3))
|
||||
output = P.Reshape()(output, (n, ih, oc, scale, ow))
|
||||
output = P.Transpose()(output, (0, 2, 1, 3, 4))
|
||||
output = P.Reshape()(output, (n, oc, oh, ow))
|
||||
return output
|
||||
|
||||
|
||||
class SmallUpSampler(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, conv, upsize, n_feats, bn=False, act=False, bias=True):
|
||||
def __init__(self, conv, upsize, n_feats, bias=True):
|
||||
super(SmallUpSampler, self).__init__()
|
||||
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
|
||||
self.reshape = P.Reshape()
|
||||
|
@ -568,7 +542,6 @@ class SmallUpSampler(nn.Cell):
|
|||
self.pixelsf = _pixelsf_
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
x = self.conv(x)
|
||||
output = self.pixelsf(x, self.upsize)
|
||||
return output
|
||||
|
@ -576,47 +549,37 @@ class SmallUpSampler(nn.Cell):
|
|||
|
||||
class Upsampler(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
|
||||
def __init__(self, conv, scale, n_feats, bias=True):
|
||||
super(Upsampler, self).__init__()
|
||||
m = []
|
||||
if (scale & (scale - 1)) == 0:
|
||||
for _ in range(int(math.log(scale, 2))):
|
||||
m.append(SmallUpSampler(conv, 2, n_feats, bias=bias))
|
||||
|
||||
elif scale == 3:
|
||||
m.append(SmallUpSampler(conv, 3, n_feats, bias=bias))
|
||||
self.net = nn.SequentialCell(m)
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class IPT(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(self, args, conv=default_conv):
|
||||
super(IPT, self).__init__()
|
||||
|
||||
self.dytpe = mstype.float16
|
||||
self.scale_idx = 0
|
||||
|
||||
self.args = args
|
||||
|
||||
self.con_loss = args.con_loss
|
||||
n_feats = args.n_feats
|
||||
kernel_size = 3
|
||||
act = nn.ReLU()
|
||||
|
||||
self.sub_mean = MeanShift(args.rgb_range)
|
||||
self.add_mean = MeanShift(args.rgb_range, sign=1)
|
||||
|
||||
self.head = nn.CellList([
|
||||
nn.SequentialCell(
|
||||
conv(args.n_colors, n_feats, kernel_size),
|
||||
ResBlock(conv, n_feats, 5, act=act),
|
||||
ResBlock(conv, n_feats, 5, act=act)
|
||||
) for _ in args.scale
|
||||
])
|
||||
nn.SequentialCell(conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe),
|
||||
ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe),
|
||||
ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe)) for _ in range(6)])
|
||||
|
||||
self.body = VisionTransformer(img_dim=args.patch_size,
|
||||
patch_dim=args.patch_dim,
|
||||
|
@ -630,36 +593,34 @@ class IPT(nn.Cell):
|
|||
mlp=args.no_mlp,
|
||||
pos_every=args.pos_every,
|
||||
no_pos=args.no_pos,
|
||||
idx=self.scale_idx)
|
||||
con_loss=args.con_loss).to_float(self.dytpe)
|
||||
|
||||
self.tail = nn.CellList([
|
||||
nn.SequentialCell(
|
||||
Upsampler(conv, s, n_feats, act=False),
|
||||
conv(n_feats, args.n_colors, kernel_size)
|
||||
) for s in args.scale
|
||||
])
|
||||
nn.SequentialCell(Upsampler(conv, s, n_feats).to_float(self.dytpe),
|
||||
conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)) \
|
||||
for s in [2, 3, 4, 1, 1, 1]])
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.tile = P.Tile()
|
||||
self.transpose = P.Transpose()
|
||||
self.s2t = P.ScalarToTensor()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
def construct(self, x, idx):
|
||||
"""ipt"""
|
||||
x = self.sub_mean(x)
|
||||
x = self.head[self.scale_idx](x)
|
||||
res = self.body(x)
|
||||
idx_num = idx.shape[0]
|
||||
x = self.head[idx_num](x)
|
||||
idx_tensor = self.cast(self.s2t(idx_num), mstype.int32)
|
||||
if self.con_loss:
|
||||
res, x_con = self.body(x, idx_tensor)
|
||||
res += x
|
||||
x = self.tail[idx_num](x)
|
||||
return x, x_con
|
||||
res = self.body(x, idx_tensor)
|
||||
res += x
|
||||
x = self.tail[self.scale_idx](res)
|
||||
x = self.add_mean(x)
|
||||
|
||||
x = self.tail[idx_num](res)
|
||||
return x
|
||||
|
||||
def set_scale(self, scale_idx):
|
||||
"""ipt"""
|
||||
self.body.query_idx = scale_idx
|
||||
self.scale_idx = scale_idx
|
||||
|
||||
|
||||
class IPT_post():
|
||||
"""ipt"""
|
||||
def __init__(self, model, args):
|
||||
|
@ -674,17 +635,13 @@ class IPT_post():
|
|||
self.cc_2 = P.Concat(axis=2)
|
||||
self.cc_3 = P.Concat(axis=3)
|
||||
|
||||
def set_scale(self, scale_idx):
|
||||
"""ipt"""
|
||||
self.body.query_idx = scale_idx
|
||||
self.scale_idx = scale_idx
|
||||
|
||||
def forward(self, x, shave=12, batchsize=64):
|
||||
def forward(self, x, idx, shave=12, batchsize=64):
|
||||
"""ipt"""
|
||||
self.idx = idx
|
||||
h, w = x.shape[-2:]
|
||||
padsize = int(self.args.patch_size)
|
||||
shave = int(self.args.patch_size / 4)
|
||||
scale = self.args.scale[self.scale_idx]
|
||||
scale = self.args.scale[0]
|
||||
h_cut = (h - padsize) % (padsize - shave)
|
||||
w_cut = (w - padsize) % (padsize - shave)
|
||||
|
||||
|
@ -692,7 +649,7 @@ class IPT_post():
|
|||
x_unfold = unf_1.compute(x)
|
||||
x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2)
|
||||
x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
|
||||
y_hw_cut = self.model(x_hw_cut)
|
||||
y_hw_cut = self.model(x_hw_cut, self.idx)
|
||||
|
||||
x_h_cut = x[:, :, (h - padsize):, :]
|
||||
x_w_cut = x[:, :, :, (w - padsize):]
|
||||
|
@ -714,10 +671,10 @@ class IPT_post():
|
|||
for i in range(x_range):
|
||||
if i == 0:
|
||||
y_unfold = self.model(
|
||||
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
|
||||
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
|
||||
else:
|
||||
y_unfold = self.cc_0((y_unfold, self.model(
|
||||
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)))
|
||||
y_unf_shape_0 = y_unfold.shape[0]
|
||||
fold_1 = \
|
||||
_stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale),
|
||||
|
@ -740,17 +697,18 @@ class IPT_post():
|
|||
stride=padsize * scale - shave * scale)
|
||||
y_inter = fold_2.compute(self.transpose(self.reshape(
|
||||
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
|
||||
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long
|
||||
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long
|
||||
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), \
|
||||
int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter))
|
||||
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, \
|
||||
int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)]))
|
||||
concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2))
|
||||
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
|
||||
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long
|
||||
|
||||
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
|
||||
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
|
||||
y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
|
||||
y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
|
||||
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
|
||||
y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
|
||||
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
|
||||
|
||||
return y
|
||||
|
||||
def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
|
||||
|
@ -766,11 +724,11 @@ class IPT_post():
|
|||
for i in range(x_range):
|
||||
if i == 0:
|
||||
y_h_cut_unfold = self.model(
|
||||
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
|
||||
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
|
||||
else:
|
||||
y_h_cut_unfold = \
|
||||
self.cc_0((y_h_cut_unfold, self.model(
|
||||
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)))
|
||||
y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
|
||||
fold_1 = \
|
||||
_stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale),
|
||||
|
@ -802,10 +760,11 @@ class IPT_post():
|
|||
for i in range(x_range):
|
||||
if i == 0:
|
||||
y_w_cut_unfold = self.model(
|
||||
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
|
||||
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
|
||||
else:
|
||||
y_w_cut_unfold = self.cc_0((y_w_cut_unfold,
|
||||
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], \
|
||||
self.idx)))
|
||||
y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
|
||||
fold_1 = _stride_fold_(padsize * scale,
|
||||
output_shape=((h - h_cut) * scale,
|
||||
|
@ -827,7 +786,6 @@ class IPT_post():
|
|||
|
||||
class _stride_unfold_():
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=-1):
|
||||
|
@ -874,13 +832,12 @@ class _stride_unfold_():
|
|||
zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape)
|
||||
concat4 = np.concatenate((concat3, zeros4), axis=3)
|
||||
unf_x += concat4
|
||||
unf_x = Tensor(unf_x, mstype.float32)
|
||||
unf_x = Tensor(unf_x, mstype.float16)
|
||||
y = self.unfold(unf_x)
|
||||
return y
|
||||
|
||||
class _stride_fold_():
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
output_shape=(-1, -1),
|
||||
|
@ -905,7 +862,7 @@ class _stride_fold_():
|
|||
self.fold = _fold_(self.kernel_size, self.large_shape)
|
||||
|
||||
def compute(self, x):
|
||||
'''stride'''
|
||||
""" compute"""
|
||||
NumBlock_x = self.NumBlock_x
|
||||
NumBlock_y = self.NumBlock_y
|
||||
large_x = self.fold(x)
|
||||
|
@ -917,7 +874,8 @@ class _stride_fold_():
|
|||
leftup_idx_x.append(i * self.kernel_size[0])
|
||||
for i in range(NumBlock_y):
|
||||
leftup_idx_y.append(i * self.kernel_size[1])
|
||||
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long
|
||||
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], \
|
||||
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32)
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
fold_i = i * self.stride
|
||||
|
@ -938,12 +896,11 @@ class _stride_fold_():
|
|||
zeros4 = np.zeros(t4.shape)
|
||||
concat4 = np.concatenate((concat3, zeros4), axis=3)
|
||||
fold_x += concat4
|
||||
y = Tensor(fold_x, mstype.float32)
|
||||
y = Tensor(fold_x, mstype.float16)
|
||||
return y
|
||||
|
||||
class _unfold_(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(
|
||||
self, kernel_size, stride=-1):
|
||||
|
||||
|
@ -965,8 +922,10 @@ class _unfold_(nn.Cell):
|
|||
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
|
||||
output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size))
|
||||
output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4))
|
||||
output_img = self.reshape(output_img, (N*C, numH, numW, self.kernel_size, self.kernel_size))
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
|
||||
output_img = self.reshape(output_img, (N, C, numH * numW, self.kernel_size*self.kernel_size))
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 3))
|
||||
output_img = self.reshape(output_img, (N, numH * numW, -1))
|
||||
return output_img
|
||||
|
||||
|
@ -1002,14 +961,10 @@ class _fold_(nn.Cell):
|
|||
org_W = self.output_shape[1]
|
||||
numH = org_H // self.kernel_size[0]
|
||||
numW = org_W // self.kernel_size[1]
|
||||
output_img = self.reshape(
|
||||
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
output_img = self.reshape(x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
|
||||
output_img = self.transpose(output_img, (0, 2, 3, 1, 4))
|
||||
output_img = self.reshape(output_img, (N*org_C, self.kernel_size[0], numH, numW, self.kernel_size[1]))
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
|
||||
output_img = self.reshape(
|
||||
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
|
||||
|
||||
output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
|
||||
return output_img
|
|
@ -0,0 +1,125 @@
|
|||
"""loss"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
class SupConLoss(nn.Cell):
|
||||
"""SupConLoss"""
|
||||
def __init__(self, temperature=0.07, contrast_mode='all',
|
||||
base_temperature=0.07):
|
||||
super(SupConLoss, self).__init__()
|
||||
self.temperature = temperature
|
||||
self.contrast_mode = contrast_mode
|
||||
self.base_temperature = base_temperature
|
||||
self.normalize = P.L2Normalize(axis=2)
|
||||
self.eye = P.Eye()
|
||||
self.unbind = P.Unstack(axis=1)
|
||||
self.cat = P.Concat(axis=0)
|
||||
self.matmul = P.MatMul()
|
||||
self.div = P.Div()
|
||||
self.transpose = P.Transpose()
|
||||
self.maxes = P.ArgMaxWithValue(axis=1, keep_dims=True)
|
||||
self.tile = P.Tile()
|
||||
self.scatter = P.ScatterNd()
|
||||
self.oneslike = P.OnesLike()
|
||||
self.exp = P.Exp()
|
||||
self.sum = P.ReduceSum(keep_dims=True)
|
||||
self.log = P.Log()
|
||||
self.reshape = P.Reshape()
|
||||
self.mean = P.ReduceMean()
|
||||
|
||||
def construct(self, features):
|
||||
"""SupConLoss"""
|
||||
features = self.normalize(features)
|
||||
batch_size = features.shape[0]
|
||||
mask = self.eye(batch_size, batch_size, mstype.float32)
|
||||
contrast_count = features.shape[1]
|
||||
contrast_feature = self.cat(self.unbind(features))
|
||||
if self.contrast_mode == 'all':
|
||||
anchor_feature = contrast_feature
|
||||
anchor_count = contrast_count
|
||||
else:
|
||||
anchor_feature = features[:, 0]
|
||||
anchor_count = 1
|
||||
anchor_dot_contrast = self.div(self.matmul(anchor_feature, self.transpose(contrast_feature, (1, 0))), \
|
||||
self.temperature)
|
||||
_, logits_max = self.maxes(anchor_dot_contrast)
|
||||
logits = anchor_dot_contrast - logits_max
|
||||
mask = self.tile(mask, (anchor_count, contrast_count))
|
||||
logits_mask = 1 - self.eye(mask.shape[0], mask.shape[1], mstype.float32)
|
||||
mask = mask * logits_mask
|
||||
exp_logits = self.exp(logits) * logits_mask
|
||||
log_prob = logits - self.log(self.sum(exp_logits, 1) + 1e-8)
|
||||
mean_log_prob_pos = self.sum((mask * log_prob), 1) / self.sum(mask, 1)
|
||||
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
|
||||
loss = self.mean(self.reshape(loss, (anchor_count, batch_size)))
|
||||
return loss, anchor_count
|
||||
|
||||
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self, grads, clip_type, clip_value):
|
||||
"""ClipGradients"""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
t = self.cast(t, dt)
|
||||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
|
@ -1,4 +1,4 @@
|
|||
'''metrics'''
|
||||
"""metrics"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -13,12 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def quantize(img, rgb_range):
|
||||
'''metrics'''
|
||||
"""metrics"""
|
||||
pixel_range = 255 / rgb_range
|
||||
img = np.multiply(img, pixel_range)
|
||||
img = np.clip(img, 0, 255)
|
||||
|
@ -26,15 +26,14 @@ def quantize(img, rgb_range):
|
|||
return img
|
||||
|
||||
|
||||
def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
|
||||
'''metrics'''
|
||||
def calc_psnr(sr, hr, scale, rgb_range):
|
||||
"""metrics"""
|
||||
hr = np.float32(hr)
|
||||
sr = np.float32(sr)
|
||||
diff = (sr - hr) / rgb_range
|
||||
gray_coeffs = np.array([65.738, 129.057, 25.064]
|
||||
).reshape((1, 3, 1, 1)) / 256
|
||||
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
|
||||
diff = np.multiply(diff, gray_coeffs).sum(1)
|
||||
if np.size(hr) == 1:
|
||||
if hr.size == 1:
|
||||
return 0
|
||||
if scale != 1:
|
||||
shave = scale
|
||||
|
@ -49,7 +48,7 @@ def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
|
|||
|
||||
|
||||
def rgb2ycbcr(img, y_only=True):
|
||||
'''metrics'''
|
||||
"""metrics"""
|
||||
img.astype(np.float32)
|
||||
if y_only:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
|
|
|
@ -1,67 +0,0 @@
|
|||
'''temp'''
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
def set_template(args):
|
||||
'''temp'''
|
||||
if args.template.find('jpeg') >= 0:
|
||||
args.data_train = 'DIV2K_jpeg'
|
||||
args.data_test = 'DIV2K_jpeg'
|
||||
args.epochs = 200
|
||||
args.decay = '100'
|
||||
|
||||
if args.template.find('EDSR_paper') >= 0:
|
||||
args.model = 'EDSR'
|
||||
args.n_resblocks = 32
|
||||
args.n_feats = 256
|
||||
args.res_scale = 0.1
|
||||
|
||||
if args.template.find('MDSR') >= 0:
|
||||
args.model = 'MDSR'
|
||||
args.patch_size = 48
|
||||
args.epochs = 650
|
||||
|
||||
if args.template.find('DDBPN') >= 0:
|
||||
args.model = 'DDBPN'
|
||||
args.patch_size = 128
|
||||
args.scale = '4'
|
||||
|
||||
args.data_test = 'Set5'
|
||||
|
||||
args.batch_size = 20
|
||||
args.epochs = 1000
|
||||
args.decay = '500'
|
||||
args.gamma = 0.1
|
||||
args.weight_decay = 1e-4
|
||||
|
||||
args.loss = '1*MSE'
|
||||
|
||||
if args.template.find('GAN') >= 0:
|
||||
args.epochs = 200
|
||||
args.lr = 5e-5
|
||||
args.decay = '150'
|
||||
|
||||
if args.template.find('RCAN') >= 0:
|
||||
args.model = 'RCAN'
|
||||
args.n_resgroups = 10
|
||||
args.n_resblocks = 20
|
||||
args.n_feats = 64
|
||||
args.chop = True
|
||||
|
||||
if args.template.find('VDSR') >= 0:
|
||||
args.model = 'VDSR'
|
||||
args.n_resblocks = 20
|
||||
args.n_feats = 64
|
||||
args.patch_size = 41
|
||||
args.lr = 1e-1
|
|
@ -0,0 +1,173 @@
|
|||
"""utils"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import time
|
||||
from bisect import bisect_right
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.parallel._utils import _get_parallel_mode
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from src.loss import SupConLoss
|
||||
|
||||
|
||||
class MyTrain(nn.Cell):
|
||||
"""MyTrain"""
|
||||
def __init__(self, model, criterion, con_loss, use_con=True):
|
||||
super(MyTrain, self).__init__(auto_prefix=True)
|
||||
self.use_con = use_con
|
||||
self.model = model
|
||||
self.con_loss = con_loss
|
||||
self.criterion = criterion
|
||||
self.p = P.Print()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, lr, hr, idx):
|
||||
"""MyTrain"""
|
||||
if self.use_con:
|
||||
sr, x_con = self.model(lr, idx)
|
||||
x_con = self.cast(x_con, mstype.float32)
|
||||
sr = self.cast(sr, mstype.float32)
|
||||
loss1 = self.criterion(sr, hr)
|
||||
loss2 = self.con_loss(x_con)
|
||||
loss = loss1 + 0.1 * loss2
|
||||
else:
|
||||
sr = self.model(lr, idx)
|
||||
sr = self.cast(sr, mstype.float32)
|
||||
loss = self.criterion(sr, hr)
|
||||
return loss
|
||||
|
||||
|
||||
class MyTrainOneStepCell(nn.Cell):
|
||||
"""MyTrainOneStepCell"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = None
|
||||
parallel_mode = _get_parallel_mode()
|
||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, True, 8)
|
||||
|
||||
def construct(self, *args):
|
||||
weights = self.weights
|
||||
loss = self.network(*args)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(*args, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
def sub_mean(x):
|
||||
red_channel_mean = 0.4488 * 255
|
||||
green_channel_mean = 0.4371 * 255
|
||||
blue_channel_mean = 0.4040 * 255
|
||||
x[:, 0, :, :] -= red_channel_mean
|
||||
x[:, 1, :, :] -= green_channel_mean
|
||||
x[:, 2, :, :] -= blue_channel_mean
|
||||
return x
|
||||
|
||||
|
||||
def add_mean(x):
|
||||
red_channel_mean = 0.4488 * 255
|
||||
green_channel_mean = 0.4371 * 255
|
||||
blue_channel_mean = 0.4040 * 255
|
||||
x[:, 0, :, :] += red_channel_mean
|
||||
x[:, 1, :, :] += green_channel_mean
|
||||
x[:, 2, :, :] += blue_channel_mean
|
||||
return x
|
||||
|
||||
|
||||
class Trainer():
|
||||
"""Trainer"""
|
||||
def __init__(self, args, loader, my_model):
|
||||
self.args = args
|
||||
self.scale = args.scale
|
||||
self.trainloader = loader
|
||||
self.model = my_model
|
||||
self.model.set_train()
|
||||
self.criterion = nn.L1Loss()
|
||||
self.con_loss = SupConLoss()
|
||||
self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
|
||||
self.train_net = MyTrain(self.model, self.criterion, self.con_loss, use_con=args.con_loss)
|
||||
self.bp = MyTrainOneStepCell(self.train_net, self.optimizer, 1024.0)
|
||||
|
||||
def train(self):
|
||||
"""Trainer"""
|
||||
losses = 0
|
||||
for batch_idx, imgs in enumerate(self.trainloader):
|
||||
lr = imgs["LR"]
|
||||
hr = imgs["HR"]
|
||||
lr = Tensor(sub_mean(lr), mstype.float32)
|
||||
hr = Tensor(sub_mean(hr), mstype.float32)
|
||||
idx = Tensor(np.ones(imgs["idx"][0]), mstype.int32)
|
||||
t1 = time.time()
|
||||
loss = self.bp(lr, hr, idx)
|
||||
t2 = time.time()
|
||||
losses += loss.asnumpy()
|
||||
print('Task: %g, Step: %g, loss: %f, time: %f s' % (idx.shape[0], batch_idx, loss.asnumpy(), t2 - t1),
|
||||
flush=True)
|
||||
os.makedirs(self.args.save, exist_ok=True)
|
||||
if self.args.rank == 0:
|
||||
save_checkpoint(self.bp, self.args.save + "model_" + str(self.epoch) + '.ckpt')
|
||||
|
||||
def update_learning_rate(self, epoch):
|
||||
"""Update learning rates for all the networks; called at the end of every epoch.
|
||||
:param epoch: current epoch
|
||||
:type epoch: int
|
||||
:param lr: learning rate of cyclegan
|
||||
:type lr: float
|
||||
:param niter: number of epochs with the initial learning rate
|
||||
:type niter: int
|
||||
:param niter_decay: number of epochs to linearly decay learning rate to zero
|
||||
:type niter_decay: int
|
||||
"""
|
||||
self.epoch = epoch
|
||||
value = self.args.decay.split('-')
|
||||
value.sort(key=int)
|
||||
milestones = list(map(int, value))
|
||||
print("*********** epoch: {} **********".format(epoch))
|
||||
lr = self.args.lr * self.args.gamma ** bisect_right(milestones, epoch)
|
||||
self.adjust_lr('model', self.optimizer, lr)
|
||||
print("*********************************")
|
||||
|
||||
def adjust_lr(self, name, optimizer, lr):
|
||||
"""Adjust learning rate for the corresponding model.
|
||||
:param name: name of model
|
||||
:type name: str
|
||||
:param optimizer: the optimizer of the corresponding model
|
||||
:type optimizer: torch.optim
|
||||
:param lr: learning rate to be adjusted
|
||||
:type lr: float
|
||||
"""
|
||||
lr_param = optimizer.get_lr()
|
||||
lr_param.assign_value(Tensor(lr, mstype.float32))
|
||||
print('==> ' + name + ' learning rate: ', lr_param.asnumpy())
|
|
@ -0,0 +1,154 @@
|
|||
"""train"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import math
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import Parameter, set_seed, context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common.initializer import initializer, HeUniform, XavierUniform, Uniform, Normal, Zero
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.args import args
|
||||
from src.data.bicubic import bicubic
|
||||
from src.data.imagenet import ImgData
|
||||
from src.ipt_model import IPT
|
||||
from src.utils import Trainer
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(shape):
|
||||
"""
|
||||
calculate fan_in and fan_out
|
||||
|
||||
Args:
|
||||
shape (tuple): input shape.
|
||||
|
||||
Returns:
|
||||
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
|
||||
"""
|
||||
dimensions = len(shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2:
|
||||
fan_in = shape[1]
|
||||
fan_out = shape[0]
|
||||
else:
|
||||
num_input_fmaps = shape[1]
|
||||
num_output_fmaps = shape[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = shape[2] * shape[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
:param net: network to be initialized
|
||||
:type net: nn.Module
|
||||
:param init_type: the name of an initialization method: normal | xavier | kaiming | orthogonal
|
||||
:type init_type: str
|
||||
:param init_gain: scaling factor for normal, xavier and orthogonal.
|
||||
:type init_gain: float
|
||||
"""
|
||||
|
||||
for _, cell in net.cells_and_names():
|
||||
classname = cell.__class__.__name__
|
||||
if hasattr(cell, 'in_proj_layer'):
|
||||
cell.in_proj_layer = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.in_proj_layer.shape,
|
||||
cell.in_proj_layer.dtype), name=cell.in_proj_layer.name)
|
||||
if hasattr(cell, 'weight'):
|
||||
if init_type == 'normal':
|
||||
cell.weight = Parameter(initializer(Normal(init_gain), cell.weight.shape,
|
||||
cell.weight.dtype), name=cell.weight.name)
|
||||
elif init_type == 'xavier':
|
||||
cell.weight = Parameter(initializer(XavierUniform(init_gain), cell.weight.shape,
|
||||
cell.weight.dtype), name=cell.weight.name)
|
||||
elif init_type == "he":
|
||||
cell.weight = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.weight.shape,
|
||||
cell.weight.dtype), name=cell.weight.name)
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
|
||||
if hasattr(cell, 'bias') and cell.bias is not None:
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
cell.bias = Parameter(initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype),
|
||||
name=cell.bias.name)
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
cell.gamma = Parameter(initializer(Normal(1.0), cell.gamma.default_input.shape()), name=cell.gamma.name)
|
||||
cell.beta = Parameter(initializer(Zero(), cell.beta.default_input.shape()), name=cell.beta.name)
|
||||
|
||||
print('initialize network weight with %s' % init_type)
|
||||
|
||||
def train_net(distribute, imagenet, epochs):
|
||||
"""Train net"""
|
||||
set_seed(1)
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
if imagenet == 1:
|
||||
train_dataset = ImgData(args)
|
||||
else:
|
||||
train_dataset = data.Data(args).loader_train
|
||||
|
||||
if distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
|
||||
print('Rank {}, rank_size {}'.format(rank_id, rank_size))
|
||||
if imagenet == 1:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset,
|
||||
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
|
||||
num_shards=rank_size, shard_id=args.rank, shuffle=True)
|
||||
else:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=rank_size,
|
||||
shard_id=rank_id, shuffle=True)
|
||||
else:
|
||||
if imagenet == 1:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset,
|
||||
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
|
||||
shuffle=True)
|
||||
else:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], shuffle=True)
|
||||
|
||||
resize_fuc = bicubic()
|
||||
train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
|
||||
train_de_dataset = train_de_dataset.batch(args.batch_size,
|
||||
input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
|
||||
output_columns=["LR", "HR", "idx", "filename"],
|
||||
drop_remainder=True, per_batch_map=resize_fuc.forward)
|
||||
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
|
||||
|
||||
net_work = IPT(args)
|
||||
init_weights(net_work, init_type='he', init_gain=1.0)
|
||||
print("Init net weight successfully")
|
||||
if args.pth_path:
|
||||
param_dict = load_checkpoint(args.pth_path)
|
||||
load_param_into_net(net_work, param_dict)
|
||||
print("Load net weight successfully")
|
||||
|
||||
train_func = Trainer(args, train_loader, net_work)
|
||||
for epoch in range(0, epochs):
|
||||
train_func.update_learning_rate(epoch)
|
||||
train_func.train()
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_net(distribute=args.distribute, imagenet=args.imagenet, epochs=args.epochs)
|
|
@ -0,0 +1,95 @@
|
|||
"""train finetune"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.common import set_seed
|
||||
from src.args import args
|
||||
from src.data.imagenet import ImgData
|
||||
from src.data.srdata import SRData
|
||||
from src.data.div2k import DIV2K
|
||||
from src.data.bicubic import bicubic
|
||||
from src.ipt_model import IPT
|
||||
from src.utils import Trainer
|
||||
|
||||
def train_net(distribute, imagenet):
|
||||
"""Train net with finetune"""
|
||||
set_seed(1)
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
if imagenet == 1:
|
||||
train_dataset = ImgData(args)
|
||||
elif not args.derain:
|
||||
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
|
||||
train_dataset.set_scale(args.task_id)
|
||||
else:
|
||||
train_dataset = SRData(args, name=args.data_train, train=True, benchmark=False)
|
||||
train_dataset.set_scale(args.task_id)
|
||||
|
||||
if distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
|
||||
print('Rank {}, group_size {}'.format(rank_id, rank_size))
|
||||
if imagenet == 1:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset,
|
||||
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
|
||||
num_shards=rank_size, shard_id=rank_id, shuffle=True)
|
||||
else:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"],
|
||||
num_shards=rank_size, shard_id=rank_id, shuffle=True)
|
||||
else:
|
||||
if imagenet == 1:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset,
|
||||
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
|
||||
shuffle=True)
|
||||
else:
|
||||
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"], shuffle=True)
|
||||
|
||||
if args.imagenet == 1:
|
||||
resize_fuc = bicubic()
|
||||
train_de_dataset = train_de_dataset.batch(
|
||||
args.batch_size,
|
||||
input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
|
||||
output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
|
||||
per_batch_map=resize_fuc.forward)
|
||||
else:
|
||||
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
|
||||
|
||||
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
|
||||
net_m = IPT(args)
|
||||
print("Init net weights successfully")
|
||||
|
||||
if args.pth_path:
|
||||
param_dict = load_checkpoint(args.pth_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
print("Load net weight successfully")
|
||||
|
||||
train_func = Trainer(args, train_loader, net_m)
|
||||
|
||||
for epoch in range(0, args.epochs):
|
||||
train_func.update_learning_rate(epoch)
|
||||
train_func.train()
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_net(distribute=args.distribute, imagenet=args.imagenet)
|
Loading…
Reference in New Issue