add IPT net

This commit is contained in:
lilei 2021-04-20 21:12:17 +08:00
parent e0ea2767b7
commit 9022e552fd
21 changed files with 1385 additions and 756 deletions

69
model_zoo/research/cv/IPT/eval.py Executable file → Normal file
View File

@ -1,6 +1,6 @@
"""eval script"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -13,48 +13,82 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
from src import ipt
import mindspore.dataset as ds
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
import src.ipt_model as ipt
from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize
from mindspore import context
import mindspore.dataset as de
from mindspore.train.serialization import load_checkpoint, load_param_into_net
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(max_call_depth=10000)
context.set_context(mode=context.GRAPH_MODE, device_target="ASCEND", device_id=0)
def sub_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] -= red_channel_mean
x[:, 1, :, :] -= green_channel_mean
x[:, 2, :, :] -= blue_channel_mean
return x
def add_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] += red_channel_mean
x[:, 1, :, :] += green_channel_mean
x[:, 2, :, :] += blue_channel_mean
return x
def main():
def eval_net():
"""eval"""
args.batch_size = 128
args.decay = 70
args.patch_size = 48
args.num_queries = 6
args.model = 'vtip'
args.num_layers = 4
if args.epochs == 0:
args.epochs = 1e8
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False)
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR', "idx", "filename"], shuffle=False)
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
train_loader = train_de_dataset.create_dict_iterator()
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
net_m = ipt.IPT(args)
print('load mindspore net successfully.')
if args.pth_path:
param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_m, param_dict)
net_m.set_train(False)
idx = Tensor(np.ones(args.task_id), mstype.int32)
inference = ipt.IPT_post(net_m, args)
print('load mindspore net successfully.')
num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1))
inference = ipt.IPT_post(net_m, args)
for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR']
hr = imgs['HR']
hr_np = np.float32(hr.asnumpy())
pred = inference.forward(lr)
pred_np = np.float32(pred.asnumpy())
lr = sub_mean(lr)
lr = Tensor(lr, mstype.float32)
pred = inference.forward(lr, idx)
pred_np = add_mean(pred.asnumpy())
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
psnr = calc_psnr(pred_np, hr, 4, 255.0)
print("current psnr: ", psnr)
psnrs[batch_idx, 0] = psnr
if args.denoise:
print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
@ -63,7 +97,6 @@ def main():
else:
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
if __name__ == '__main__':
print("Start main function!")
main()
print("Start eval function!")
eval_net()

View File

@ -1,26 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vitm import ViT
def IPT(*args, **kwargs):
return ViT(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == 'IPT':
return IPT(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

75
model_zoo/research/cv/IPT/readme.md Executable file → Normal file
View File

@ -45,9 +45,9 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on
## Requirements
### Hardware (GPU)
### Hardware (Ascend)
> Prepare hardware environment with GPU.
> Prepare hardware environment with Ascend.
### Framework
@ -67,34 +67,73 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on
```bash
IPT
├── eval.py # inference entry
├── train.py # pre-training entry
├── train_finetune.py # fine-tuning entry
├── image
│   └── ipt.png # the illustration of IPT network
├── model
│   ├── IPT_denoise30.ckpt # denoise model weights for noise level 30
│   ├── IPT_denoise50.ckpt # denoise model weights for noise level 50
│   ├── IPT_derain.ckpt # derain model weights
│   ├── IPT_sr2.ckpt # X2 super-resolution model weights
│   ├── IPT_sr3.ckpt # X3 super-resolution model weights
│   └── IPT_sr4.ckpt # X4 super-resolution model weights
├── readme.md # Readme
├── scripts
│   └── run_eval.sh # inference script for all tasks
│   ├── run_eval.sh # inference script for all tasks
│   ├── run_distributed.sh # pre-training script for all tasks
│   └── run_finetune_distributed.sh # fine-tuning script for all tasks
└── src
├── args.py # options/hyper-parameters of IPT
├── data
│   ├── common.py # common dataset
│   ├── __init__.py # Class data init function
│   └── srdata.py # flow of loading sr data
├── foldunfold_stride.py # function of fold and unfold operations for images
│   ├── bicubic.py # scripts for data pre-processing
│   ├── div2k.py # DIV2K dataset
│   ├── imagenet.py # Imagenet data for pre-training
│   └── srdata.py # All dataset
├── metrics.py # PSNR calculator
├── template.py # setting of model selection
└── vitm.py # IPT network
├── utils.py # training scripts
├── loss.py # contrastive_loss
└── ipt_model.py # IPT network
```
### Script Parameter
> For details about hyperparameters, see src/args.py.
## Training Process
### For pre-training
```bash
python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save $SAVE_PATH
```
> Or one can run following script for all tasks.
```bash
sh scripts/run_distributed.sh RANK_TABLE_FILE DATA_PATH
```
### For fine-tuning
> For SR tasks:
```bash
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 50
```
> For Denoising tasks:
```bash
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --denoise --sigma $Noise --epochs 50
```
> For deraining tasks:
```bash
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --derain --epochs 50
```
> Or one can run following script for all tasks.
```bash
sh scripts/run_finetune_distributed.sh RANK_TABLE_FILE DATA_PATH MODEL TASK_ID
```
## Evaluation
### Evaluation Process
@ -103,13 +142,13 @@ IPT
> For SR x4:
```bash
python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE
```
> Or one can run following script for all tasks.
```bash
sh scripts/run_eval.sh
sh scripts/run_eval.sh DATA_PATH DATA_TEST MODEL TASK_ID
```
### Evaluation Result
@ -117,7 +156,7 @@ sh scripts/run_eval.sh
The result are evaluated by the value of PSNR (Peak Signal-to-Noise Ratio), and the format is as following.
```bash
result: {"Mean psnr of Se5 x4 is 32.68"}
result: {"Mean psnr of Set5 x4 is 32.68"}
```
## Performance

View File

@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../*.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env$i.log
python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save experiments/ckpt_new_init/ > log 2>&1 &
cd ..
done

55
model_zoo/research/cv/IPT/scripts/run_eval.sh Executable file → Normal file
View File

@ -14,18 +14,49 @@
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
DATA_SET=$3
PATH_CHECKPOINT=$4
ulimit -u unlimited
export DATA_PATH=$1
export DATA_TEST=$2
export MODEL=$3
export TASK_ID=$4
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
if [[ $TASK_ID -lt 3 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
SCALE=$[$TASK_ID+2]
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE > log 2>&1 &
fi
##denoise
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
if [[ $TASK_ID -eq 3 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --derain > log 2>&1 &
fi
##derain
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
if [[ $TASK_ID -eq 4 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 30 > log 2>&1 &
fi
if [[ $TASK_ID -eq 5 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 50 > log 2>&1 &
fi

View File

@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
export MODEL=$3
export TASK_ID=$4
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../*.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env$i.log
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 100 > log 2>&1 &
cd ..
done

24
model_zoo/research/cv/IPT/src/args.py Executable file → Normal file
View File

@ -1,4 +1,4 @@
'''args'''
"""args"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
from src import template
parser = argparse.ArgumentParser(description='EDSR and MDSR')
@ -24,12 +24,6 @@ parser.add_argument('--template', default='.',
help='You can set various templates in option.py')
# Hardware specifications
parser.add_argument('--n_threads', type=int, default=6,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
help='random seed')
@ -60,9 +54,8 @@ parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='vtip',
parser.add_argument('--model', default='EDSR',
help='model name')
parser.add_argument('--act', type=str, default='relu',
help='activation function')
parser.add_argument('--pre_train', type=str, default='',
@ -139,6 +132,7 @@ parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
# Loss specifications
parser.add_argument('--con_loss', action='store_true')
parser.add_argument('--loss', type=str, default='1*L1',
help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default='1e8',
@ -161,6 +155,7 @@ parser.add_argument('--save_gt', action='store_true',
help='save low-resolution and high-resolution images together')
parser.add_argument('--scalelr', type=int, default=0)
# cloud
parser.add_argument('--moxfile', type=int, default=1)
parser.add_argument('--imagenet', type=int, default=0)
@ -169,10 +164,11 @@ parser.add_argument('--train_url', type=str, help='train_dir')
parser.add_argument('--pretrain', type=str, default='')
parser.add_argument('--pth_path', type=str, default='')
parser.add_argument('--load_query', type=int, default=0)
# transformer
parser.add_argument('--patch_dim', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=12)
parser.add_argument('--num_layers', type=int, default=12)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--dropout_rate', type=float, default=0)
parser.add_argument('--no_norm', action='store_true')
parser.add_argument('--post_norm', action='store_true')
@ -192,8 +188,10 @@ parser.add_argument('--sigma', type=float, default=25)
parser.add_argument('--derain', action='store_true')
parser.add_argument('--finetune', action='store_true')
parser.add_argument('--derain_test', type=int, default=10)
# alltask
parser.add_argument('--alltask', action='store_true')
parser.add_argument('--task_id', type=int, default=0)
# dehaze
parser.add_argument('--dehaze', action='store_true')
@ -201,6 +199,7 @@ parser.add_argument('--dehaze_test', type=int, default=100)
parser.add_argument('--indoor', action='store_true')
parser.add_argument('--outdoor', action='store_true')
parser.add_argument('--nochange', action='store_true')
# deblur
parser.add_argument('--deblur', action='store_true')
parser.add_argument('--deblur_test', type=int, default=1000)
@ -210,6 +209,8 @@ parser.add_argument('--init_method', type=str,
default=None, help='master address')
parser.add_argument('--rank', type=int, default=0,
help='Index of current task')
parser.add_argument('--group_size', type=int, default=0,
help='Index of current task')
parser.add_argument('--world_size', type=int, default=1,
help='Total number of tasks')
parser.add_argument('--gpu', default=None, type=int,
@ -223,7 +224,6 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
parser.add_argument('--distribute', action='store_true')
args, unparsed = parser.parse_known_args()
template.set_template(args)
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')

View File

@ -1,35 +0,0 @@
"""data"""
# Copyright 2021 Huawei Technologies Co., Ltd"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from importlib import import_module
class Data:
"""data"""
def __init__(self, args):
self.loader_train = None
self.loader_test = []
for d in args.data_test:
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']:
m = import_module('data.benchmark')
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
else:
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
m = import_module('data.' + module_name.lower())
testset = getattr(m, module_name)(args, train=False, name=d)
self.loader_test.append(
testset
)

View File

@ -0,0 +1,132 @@
"""bicubic"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
class bicubic:
"""bicubic"""
def __init__(self, seed=0):
self.seed = seed
self.rand_fn = np.random.RandomState(self.seed)
def cubic(self, x):
absx2 = np.abs(x) * np.abs(x)
absx3 = np.abs(x) * np.abs(x) * np.abs(x)
condition1 = (np.abs(x) <= 1).astype(np.float32)
condition2 = ((np.abs(x) > 1) & (np.abs(x) <= 2)).astype(np.float32)
f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * np.abs(x) + 2) * condition2
return f
def contribute(self, in_size, out_size, scale):
"""bicubic"""
kernel_width = 4
if scale < 1:
kernel_width = 4 / scale
x0 = np.arange(start=1, stop=out_size[0]+1).astype(np.float32)
x1 = np.arange(start=1, stop=out_size[1]+1).astype(np.float32)
u0 = x0 / scale + 0.5 * (1 - 1 / scale)
u1 = x1 / scale + 0.5 * (1 - 1 / scale)
left0 = np.floor(u0 - kernel_width / 2)
left1 = np.floor(u1 - kernel_width / 2)
width = np.ceil(kernel_width) + 2
indice0 = np.expand_dims(left0, axis=1) + \
np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)
indice1 = np.expand_dims(left1, axis=1) + \
np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)
mid0 = np.expand_dims(u0, axis=1) - np.expand_dims(indice0, axis=0)
mid1 = np.expand_dims(u1, axis=1) - np.expand_dims(indice1, axis=0)
if scale < 1:
weight0 = scale * self.cubic(mid0 * scale)
weight1 = scale * self.cubic(mid1 * scale)
else:
weight0 = self.cubic(mid0)
weight1 = self.cubic(mid1)
weight0 = weight0 / (np.expand_dims(np.sum(weight0, axis=2), 2))
weight1 = weight1 / (np.expand_dims(np.sum(weight1, axis=2), 2))
indice0 = np.expand_dims(np.minimum(np.maximum(1, indice0), in_size[0]), axis=0)
indice1 = np.expand_dims(np.minimum(np.maximum(1, indice1), in_size[1]), axis=0)
kill0 = np.equal(weight0, 0)[0][0]
kill1 = np.equal(weight1, 0)[0][0]
weight0 = weight0[:, :, kill0 == 0]
weight1 = weight1[:, :, kill1 == 0]
indice0 = indice0[:, :, kill0 == 0]
indice1 = indice1[:, :, kill1 == 0]
return weight0, weight1, indice0, indice1
def forward(self, hr, rain, lrx2, lrx3, lrx4, filename, batchInfo):
"""bicubic"""
idx = self.rand_fn.randint(0, 6)
if idx < 3:
if idx == 0:
scale = 1/2
hr = lrx2
elif idx == 1:
scale = 1/3
hr = lrx3
elif idx == 2:
scale = 1/4
hr = lrx4
hr = np.array(hr)
[_, _, h, w] = hr.shape
weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale)
weight0 = np.asarray(weight0[0], dtype=np.float32)
indice0 = np.asarray(indice0[0], dtype=np.float32).astype(np.long)
weight0 = np.expand_dims(np.expand_dims(np.expand_dims(weight0, axis=0), axis=1), axis=4)
out = hr[:, :, (indice0-1), :] * weight0
out = np.sum(out, axis=3)
A = np.transpose(out, (0, 1, 3, 2))
weight1 = np.asarray(weight1[0], dtype=np.float32)
weight1 = np.expand_dims(np.expand_dims(np.expand_dims(weight1, axis=0), axis=1), axis=4)
indice1 = np.asarray(indice1[0], dtype=np.float32).astype(np.long)
out = A[:, :, (indice1-1), :] * weight1
out = np.round(255 * np.transpose(np.sum(out, axis=3), (0, 1, 3, 2)))/255
out = np.clip(np.round(out), 0, 255)
lr = list(out)
hr = list(hr)
else:
if idx == 3:
hr = np.array(hr)
rain = np.array(rain)
lr = np.clip((rain + hr), 0, 255)
hr = list(hr)
lr = list(lr)
elif idx == 4:
hr = np.array(hr)
noise = np.random.randn(*hr.shape) * 30
lr = np.clip(noise + hr, 0, 255)
hr = list(hr)
lr = list(lr)
elif idx == 5:
hr = np.array(hr)
noise = np.random.randn(*hr.shape) * 50
lr = np.clip(noise + hr, 0, 255)
hr = list(hr)
lr = list(lr)
return lr, hr, [idx] * len(hr), filename

43
model_zoo/research/cv/IPT/src/data/common.py Executable file → Normal file
View File

@ -1,6 +1,6 @@
"""common"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import random
import numpy as np
import skimage.color as sc
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
def get_patch(*args, patch_size=96, scale=2, input_large=False):
"""common"""
ih, iw = args[0].shape[:2]
@ -34,25 +32,19 @@ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
else:
tx, ty = ix, iy
ret = [
args[0][iy:iy + ip, ix:ix + ip, :],
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
]
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
return ret
def set_channel(*args, n_channels=3):
"""common"""
def _set_channel(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
c = img.shape[2]
if n_channels == 1 and c == 3:
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
elif n_channels == 3 and c == 1:
if n_channels == 3 and c == 1:
img = np.concatenate([img] * n_channels, 2)
return img[:, :, :n_channels]
@ -61,14 +53,11 @@ def set_channel(*args, n_channels=3):
def np2Tensor(*args, rgb_range=255):
"""common"""
def _np2Tensor(img):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
tensor = np_transpose.astype(np.float32)
tensor = tensor * (rgb_range / 255)
return tensor
input_data = np_transpose.astype(np.float32)
output = input_data * (rgb_range / 255)
return output
return [_np2Tensor(a) for a in args]
@ -79,6 +68,7 @@ def augment(*args, hflip=True, rot=True):
rot90 = rot and random.random() < 0.5
def _augment(img):
"""common"""
if hflip:
img = img[:, ::-1, :]
if vflip:
@ -88,3 +78,18 @@ def augment(*args, hflip=True, rot=True):
return img
return [_augment(a) for a in args]
def search(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
return item_list

View File

@ -0,0 +1,45 @@
"""div2k"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from src.data.srdata import SRData
class DIV2K(SRData):
"""DIV2K"""
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
data_range = [r.split('-') for r in args.data_range.split('/')]
if train:
data_range = data_range[0]
else:
if args.test_only and len(data_range) == 1:
data_range = data_range[0]
else:
data_range = data_range[1]
self.begin, self.end = list(map(int, data_range))
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
def _scan(self):
names_hr, names_lr = super(DIV2K, self)._scan()
names_hr = names_hr[self.begin - 1:self.end]
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
return names_hr, names_lr
def _set_filesystem(self, dir_data):
super(DIV2K, self)._set_filesystem(dir_data)
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')

View File

@ -0,0 +1,171 @@
"""imagent"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import random
import io
from PIL import Image
import numpy as np
import imageio
def search(root, target="JPEG"):
"""imagent"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('.')[-1] == target:
item_list.append(path)
elif path.split('/')[-1].startswith(target):
item_list.append(path)
return item_list
def get_patch_img(img, patch_size=96, scale=2):
"""imagent"""
ih, iw = img.shape[:2]
tp = scale * patch_size
if (iw - tp) > -1 and (ih-tp) > 1:
ix = random.randrange(0, iw-tp+1)
iy = random.randrange(0, ih-tp+1)
hr = img[iy:iy+tp, ix:ix+tp, :3]
elif (iw - tp) > -1 and (ih - tp) <= -1:
ix = random.randrange(0, iw-tp+1)
hr = img[:, ix:ix+tp, :3]
pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
hr = np.array(pil_img)
elif (iw - tp) <= -1 and (ih - tp) > -1:
iy = random.randrange(0, ih-tp+1)
hr = img[iy:iy+tp, :, :3]
pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
hr = np.array(pil_img)
else:
pil_img = Image.fromarray(img).resize((tp, tp), Image.BILINEAR)
hr = np.array(pil_img)
return hr
class ImgData():
"""imagent"""
def __init__(self, args, train=True):
self.input_large = (args.model == 'VDSR')
self.scale = args.scale
self.idx_scale = 0
self.dataroot = args.dir_data
self.img_list = search(os.path.join(self.dataroot, "train"), "JPEG")
self.img_list.extend(search(os.path.join(self.dataroot, "val"), "JPEG"))
self.img_list = sorted(self.img_list)
self.train = train
self.args = args
self.len = len(self.img_list)
print("data length:", len(self.img_list))
if self.args.derain:
self.derain_dataroot = os.path.join(self.dataroot, "RainTrainL")
self.derain_img_list = search(self.derain_dataroot, "rainstreak")
def __len__(self):
return len(self.img_list)
def _get_index(self, idx):
return idx % len(self.img_list)
def _load_file(self, idx):
idx = self._get_index(idx)
f_lr = self.img_list[idx]
lr = imageio.imread(f_lr)
if len(lr.shape) == 2:
lr = np.dstack([lr, lr, lr])
return lr, f_lr
def _np2Tensor(self, img, rgb_range):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
tensor = np_transpose.astype(np.float32)
tensor = tensor * (rgb_range / 255)
return tensor
def __getitem__(self, idx):
if self.args.model == 'vtip' and self.train and self.args.alltask:
lr, filename = self._load_file(idx % self.len)
pair_list = []
rain = self._load_rain()
rain = np.expand_dims(rain, axis=2)
rain = self.get_patch(rain, 1)
rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
for idx_scale in range(4):
self.idx_scale = idx_scale
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
pair_list.append(pair_t)
return pair_list[3], rain, pair_list[0], pair_list[1], pair_list[2], [self.scale], [filename]
if self.args.model == 'vtip' and self.train and len(self.scale) > 1:
lr, filename = self._load_file(idx % self.len)
pair_list = []
for idx_scale in range(3):
self.idx_scale = idx_scale
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
pair_list.append(pair_t)
return pair_list[0], pair_list[1], pair_list[2], filename
if self.args.model == 'vtip' and self.args.derain and self.scale[self.idx_scale] == 1:
lr, filename = self._load_file(idx % self.len)
rain = self._load_rain()
rain = np.expand_dims(rain, axis=2)
rain = self.get_patch(rain, 1)
rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
return pair_t, rain, filename
if self.args.jpeg:
hr, filename = self._load_file(idx % self.len)
buffer = io.BytesIO()
width, height = hr.size
patch_size = self.scale[self.idx_scale]*self.args.patch_size
if width < patch_size:
hr = hr.resize((patch_size, height), Image.ANTIALIAS)
width, height = hr.size
if height < patch_size:
hr = hr.resize((width, patch_size), Image.ANTIALIAS)
hr.save(buffer, format='jpeg', quality=25)
lr = Image.open(buffer)
lr = np.array(lr).astype(np.float32)
hr = np.array(hr).astype(np.float32)
lr = self.get_patch(lr)
hr = self.get_patch(hr)
lr = self._np2Tensor(lr, rgb_range=self.args.rgb_range)
hr = self._np2Tensor(hr, rgb_range=self.args.rgb_range)
return lr, hr, filename
lr, filename = self._load_file(idx % self.len)
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
return pair_t, filename
def _load_rain(self):
idx = random.randint(0, len(self.derain_img_list) - 1)
f_lr = self.derain_img_list[idx]
lr = imageio.imread(f_lr)
return lr
def get_patch(self, lr, scale=0):
if scale == 0:
scale = self.scale[self.idx_scale]
lr = get_patch_img(lr, patch_size=self.args.patch_size, scale=scale)
return lr
def set_scale(self, idx_scale):
self.idx_scale = idx_scale

170
model_zoo/research/cv/IPT/src/data/srdata.py Executable file → Normal file
View File

@ -1,6 +1,6 @@
"""srdata"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import glob
import random
@ -20,43 +21,12 @@ import pickle
import numpy as np
import imageio
from src.data import common
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def search(root, target="JPEG"):
class SRData:
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
else:
item_list = []
return item_list
def search_dehaze(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
extend_list = search_dehaze(path, target)
if extend_list is not None:
item_list.extend(extend_list)
elif path.split('/')[-2].endswith(target):
item_list.append(path)
return item_list
class SRData():
"""srdata"""
def __init__(self, args, name='', train=True, benchmark=False):
self.args = args
self.name = name
@ -69,37 +39,46 @@ class SRData():
self.idx_scale = 0
if self.args.derain:
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.derain_lr_test = search(self.derain_test, "rain")
self.derain_hr_test = [path.replace(
"rainy/", "no") for path in self.derain_lr_test]
if self.train:
self.derain_dataroot = os.path.join(args.dir_data, "RainTrainL")
self.clear_train = common.search(self.derain_dataroot, "norain")
self.rain_train = []
for path in self.clear_train:
change_path = path.split('/')
change_path[-1] = change_path[-1][2:]
self.rain_train.append('/'.join(change_path))
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.deblur_lr_test = common.search(self.derain_test, "rain")
self.deblur_hr_test = [path.replace("rainy/", "no") for path in self.deblur_lr_test]
self.derain_hr_test = self.deblur_hr_test
else:
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.derain_lr_test = common.search(self.derain_test, "rain")
self.derain_hr_test = [path.replace("rainy/", "no") for path in self.derain_lr_test]
self._set_filesystem(args.dir_data)
self._set_img(args)
if self.args.derain and self.train:
self.images_hr, self.images_lr = self.clear_train, self.rain_train
if train:
self._repeat(args)
def _set_img(self, args):
"""srdata"""
if args.ext.find('img') < 0:
path_bin = os.path.join(self.apath, 'bin')
os.makedirs(path_bin, exist_ok=True)
list_hr, list_lr = self._scan()
if args.ext.find('img') >= 0 or benchmark:
if args.ext.find('img') >= 0 or self.benchmark:
self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0:
os.makedirs(
self.dir_hr.replace(self.apath, path_bin),
exist_ok=True
)
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
for s in self.scale:
if s == 1:
os.makedirs(
os.path.join(self.dir_hr),
exist_ok=True
)
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
else:
os.makedirs(
os.path.join(
self.dir_lr.replace(self.apath, path_bin),
'X{}'.format(s)
),
exist_ok=True
)
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
for h in list_hr:
@ -114,23 +93,27 @@ class SRData():
self.images_lr[i].append(b)
self._check_and_load(args.ext, l, b, verbose=True)
# Below functions as used to prepare images
def _repeat(self, args):
"""srdata"""
n_patches = args.batch_size * args.test_every
n_images = len(args.data_train) * len(self.images_hr)
if n_images == 0:
self.repeat = 0
else:
self.repeat = max(n_patches // n_images, 1)
def _scan(self):
"""srdata"""
names_hr = sorted(
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
)
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
names_lr = [[] for _ in self.scale]
for f in names_hr:
filename, _ = os.path.splitext(os.path.basename(f))
for si, s in enumerate(self.scale):
if s != 1:
scale = s
names_lr[si].append(os.path.join(
self.dir_lr, 'X{}/{}x{}{}'.format(
s, filename, scale, self.ext[1]
)
))
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
.format(s, filename, scale, self.ext[1])))
for si, s in enumerate(self.scale):
if s == 1:
names_lr[si] = names_hr
@ -150,28 +133,33 @@ class SRData():
pickle.dump(imageio.imread(img), _f)
def __getitem__(self, idx):
if self.args.model == 'vtip' and self.args.derain and self.scale[
self.idx_scale] == 1 and not self.args.finetune:
norain, rain, _ = self._load_rain_test(idx)
pair = common.set_channel(
*[rain, norain], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1:
hr, _ = self._load_file_hr(idx)
if self.args.derain and self.scale[self.idx_scale] == 1:
if self.train:
lr, hr, filename = self._load_file_deblur(idx)
pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
else:
norain, rain, filename = self._load_rain_test(idx)
pair = common.set_channel(*[rain, norain], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1], [self.idx_scale], [filename]
if self.args.denoise and self.scale[self.idx_scale] == 1:
hr, filename = self._load_file_hr(idx)
pair = self.get_patch_hr(hr)
pair = common.set_channel(*[pair], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
noise = np.random.randn(*pair_t[0].shape) * self.args.sigma
lr = pair_t[0] + noise
lr = np.float32(np.clip(lr, 0, 255))
return lr, pair_t[0]
lr, hr, _ = self._load_file(idx)
return lr, pair_t[0], [self.idx_scale], [filename]
lr, hr, filename = self._load_file(idx)
pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
return pair_t[0], pair_t[1], [self.idx_scale], [filename]
def __len__(self):
if self.train:
@ -182,7 +170,6 @@ class SRData():
return len(self.images_hr)
def _get_index(self, idx):
"""srdata"""
if self.train:
return idx % len(self.images_hr)
return idx
@ -198,22 +185,9 @@ class SRData():
elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f:
hr = pickle.load(_f)
return hr, filename
def _load_rain(self, idx, rain_img=False):
"""srdata"""
idx = random.randint(0, len(self.derain_img_list) - 1)
f_lr = self.derain_img_list[idx]
if rain_img:
norain = imageio.imread(f_lr.replace("rainstreak", "norain"))
rain = imageio.imread(f_lr.replace("rainstreak", "rain"))
return norain, rain
lr = imageio.imread(f_lr)
return lr
def _load_rain_test(self, idx):
"""srdata"""
f_hr = self.derain_hr_test[idx]
f_lr = self.derain_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_lr))
@ -221,14 +195,6 @@ class SRData():
rain = imageio.imread(f_lr)
return norain, rain, filename
def _load_denoise(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_lr = self.images_hr[idx]
norain = imageio.imread(f_lr)
rain = imageio.imread(f_lr.replace("HR", "LR_bicubic"))
return norain, rain
def _load_file(self, idx):
"""srdata"""
idx = self._get_index(idx)
@ -251,12 +217,7 @@ class SRData():
def get_patch_hr(self, hr):
"""srdata"""
if self.train:
hr = self.get_patch_img_hr(
hr,
patch_size=self.args.patch_size,
scale=1
)
hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
return hr
def get_patch_img_hr(self, img, patch_size=96, scale=2):
@ -280,9 +241,7 @@ class SRData():
lr, hr = common.get_patch(
lr, hr,
patch_size=self.args.patch_size * scale,
scale=scale,
multi=(len(self.scale) > 1)
)
scale=scale)
if not self.args.no_augment:
lr, hr = common.augment(lr, hr)
else:
@ -292,7 +251,6 @@ class SRData():
return lr, hr
def set_scale(self, idx_scale):
"""srdata"""
if not self.input_large:
self.idx_scale = idx_scale
else:

View File

@ -1,241 +0,0 @@
'''stride'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class _stride_unfold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
stride=-1):
super(_stride_unfold_, self).__init__()
if stride == -1:
self.stride = kernel_size
else:
self.stride = stride
self.kernel_size = kernel_size
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.unfold = _unfold_(kernel_size)
def construct(self, x):
"""stride"""
N, C, H, W = x.shape
leftup_idx_x = []
leftup_idx_y = []
nh = int(H / self.stride)
nw = int(W / self.stride)
for i in range(nh):
leftup_idx_x.append(i * self.stride)
for i in range(nw):
leftup_idx_y.append(i * self.stride)
NumBlock_x = len(leftup_idx_x)
NumBlock_y = len(leftup_idx_y)
zeroslike = P.ZerosLike()
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
NumBlock_y * self.kernel_size), mstype.float32)
N, C, H, W = unf_x.shape
for i in range(NumBlock_x):
for j in range(NumBlock_y):
unf_i = i * self.kernel_size
unf_j = j * self.kernel_size
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size,
org_j:org_j + self.kernel_size]
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
(zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
y = self.unfold(unf_x)
return y
class _stride_fold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):
super(_stride_fold_, self).__init__()
if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
self.kernel_size = [kernel_size, kernel_size]
if stride == -1:
self.stride = kernel_size[0]
else:
self.stride = stride
self.output_shape = output_shape
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.fold = _fold_(kernel_size)
def construct(self, x):
'''stride'''
if self.output_shape[0] == -1:
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx = []
for i in range(0, H, self.kernel_size[0]):
leftup_idx.append(i)
NumBlock = len(leftup_idx)
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
for i in range(NumBlock):
for j in range(NumBlock):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx[i]
org_j = leftup_idx[j]
fills = x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
y = fold_x
else:
NumBlock_x = int(
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
NumBlock_y = int(
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
large_shape = [NumBlock_x * self.kernel_size[0],
NumBlock_y * self.kernel_size[1]]
self.fold = _fold_(self.kernel_size, large_shape)
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx_x = []
leftup_idx_y = []
for i in range(NumBlock_x):
leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1])
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
for i in range(NumBlock_x):
for j in range(NumBlock_y):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
y = fold_x
return y
class _unfold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
stride=-1):
super(_unfold_, self).__init__()
if stride == -1:
self.stride = kernel_size
self.kernel_size = kernel_size
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
'''stride'''
N, C, H, W = x.shape
numH = int(H / self.kernel_size)
numW = int(W / self.kernel_size)
if numH * self.kernel_size != H or numW * self.kernel_size != W:
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, int(
numH * numW), self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
return output_img
class _fold_(nn.Cell):
'''stride'''
def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):
super(_fold_, self).__init__()
if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
self.kernel_size = [kernel_size, kernel_size]
if stride == -1:
self.stride = kernel_size[0]
self.output_shape = output_shape
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
'''stride'''
N, C, L = x.shape
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
if self.output_shape[0] == -1:
numH = int(np.sqrt(C))
numW = int(np.sqrt(C))
org_H = int(numH * self.kernel_size[0])
org_W = org_H
else:
org_H = int(self.output_shape[0])
org_W = int(self.output_shape[1])
numH = int(org_H / self.kernel_size[0])
numW = int(org_W / self.kernel_size[1])
output_img = self.reshape(
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
output_img = self.reshape(
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
return output_img

View File

@ -13,15 +13,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import copy
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import Tensor, Parameter
class LayerPreprocess(nn.Cell):
"""
Preprocess input of each layer
"""
def __init__(self, in_channels=None):
super(LayerPreprocess, self).__init__()
self.layernorm = nn.LayerNorm((in_channels,))
self.cast = P.Cast()
self.get_dtype = P.DType()
def construct(self, input_tensor):
output = self.cast(input_tensor, mstype.float32)
output = self.layernorm(output)
output = self.cast(output, self.get_dtype(input_tensor))
return output
class MultiheadAttention(nn.Cell):
"""
@ -45,7 +60,7 @@ class MultiheadAttention(nn.Cell):
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
tensor. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16.
"""
def __init__(self,
@ -64,13 +79,12 @@ class MultiheadAttention(nn.Cell):
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=False,
compute_type=mstype.float32,
compute_type=mstype.float16,
same_dim=True):
super(MultiheadAttention, self).__init__()
self.num_attention_heads = num_attention_heads
self.size_per_head = int(hidden_width / num_attention_heads)
self.has_attention_mask = has_attention_mask
assert has_attention_mask
self.use_one_hot_embeddings = use_one_hot_embeddings
self.initializer_range = initializer_range
self.do_return_2d_tensor = do_return_2d_tensor
@ -83,11 +97,9 @@ class MultiheadAttention(nn.Cell):
self.shape_k_2d = (-1, k_tensor_width)
self.shape_v_2d = (-1, v_tensor_width)
self.hidden_width = int(hidden_width)
# units = num_attention_heads * self.size_per_head
if self.same_dim:
self.in_proj_layer = \
Parameter(Tensor(np.random.rand(hidden_width * 3,
q_tensor_width), dtype=compute_type), name="weight")
self.in_proj_layer = Parameter(Tensor(np.random.rand(hidden_width * 3,
q_tensor_width), dtype=mstype.float32), name="weight")
else:
self.query_layer = nn.Dense(q_tensor_width,
hidden_width,
@ -132,8 +144,10 @@ class MultiheadAttention(nn.Cell):
self.equal = P.Equal()
self.shape = P.Shape()
def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None):
"""Apply multihead attention."""
def construct(self, tensor_q, tensor_k, tensor_v):
"""
Apply multihead attention.
"""
batch_size, seq_length, _ = self.shape(tensor_q)
shape_qkv = (batch_size, -1,
self.num_attention_heads, self.size_per_head)
@ -161,20 +175,14 @@ class MultiheadAttention(nn.Cell):
_start = 0
_end = self.hidden_width
_w = self.in_proj_layer[_start:_end, :]
# _b = None
query_out = self.matmul_dense(_w, tensor_q_2d)
_start = self.hidden_width
_end = self.hidden_width * 2
_w = self.in_proj_layer[_start:_end, :]
# _b = None
key_out = self.matmul_dense(_w, tensor_k_2d)
_start = self.hidden_width * 2
_end = None
_w = self.in_proj_layer[_start:]
# _b = None
value_out = self.matmul_dense(_w, tensor_v_2d)
else:
query_out = self.query_layer(tensor_q_2d)
@ -193,8 +201,7 @@ class MultiheadAttention(nn.Cell):
attention_scores = self.softmax_cast(attention_scores, mstype.float32)
attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax_cast(
attention_probs, self.get_dtype(key_layer))
attention_probs = self.softmax_cast(attention_probs, mstype.float16)
if self.use_dropout:
attention_probs = self.dropout(attention_probs)
@ -212,11 +219,8 @@ class MultiheadAttention(nn.Cell):
class TransformerEncoderLayer(nn.Cell):
"""ipt"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu"):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, compute_type=mstype.float16):
super().__init__()
self.self_attn = MultiheadAttention(q_tensor_width=d_model,
k_tensor_width=d_model,
v_tensor_width=d_model,
@ -224,12 +228,12 @@ class TransformerEncoderLayer(nn.Cell):
out_tensor_width=d_model,
num_attention_heads=nhead,
attention_probs_dropout_prob=dropout)
self.linear1 = nn.Dense(d_model, dim_feedforward)
self.linear1 = nn.Dense(d_model, dim_feedforward).to_float(compute_type)
self.dropout = nn.Dropout(1. - dropout)
self.linear2 = nn.Dense(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm([d_model])
self.norm2 = nn.LayerNorm([d_model])
self.norm1 = LayerPreprocess(d_model)
self.norm2 = LayerPreprocess(d_model)
self.dropout1 = nn.Dropout(1. - dropout)
self.dropout2 = nn.Dropout(1. - dropout)
self.reshape = P.Reshape()
@ -237,7 +241,6 @@ class TransformerEncoderLayer(nn.Cell):
self.activation = P.ReLU()
def with_pos_embed(self, tensor, pos):
"""ipt"""
return tensor if pos is None else tensor + pos
def construct(self, src, pos=None):
@ -258,10 +261,8 @@ class TransformerEncoderLayer(nn.Cell):
class TransformerDecoderLayer(nn.Cell):
"""ipt"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu"):
""" ipt"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiheadAttention(q_tensor_width=d_model,
k_tensor_width=d_model,
@ -281,9 +282,9 @@ class TransformerDecoderLayer(nn.Cell):
self.dropout = nn.Dropout(1. - dropout)
self.linear2 = nn.Dense(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm([d_model])
self.norm2 = nn.LayerNorm([d_model])
self.norm3 = nn.LayerNorm([d_model])
self.norm1 = LayerPreprocess(d_model)
self.norm2 = LayerPreprocess(d_model)
self.norm3 = LayerPreprocess(d_model)
self.dropout1 = nn.Dropout(1. - dropout)
self.dropout2 = nn.Dropout(1. - dropout)
self.dropout3 = nn.Dropout(1. - dropout)
@ -291,7 +292,6 @@ class TransformerDecoderLayer(nn.Cell):
self.activation = P.ReLU()
def with_pos_embed(self, tensor, pos):
"""ipt"""
return tensor if pos is None else tensor + pos
def construct(self, tgt, memory, pos=None, query_pos=None):
@ -306,7 +306,7 @@ class TransformerDecoderLayer(nn.Cell):
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos),
tensor_k=self.with_pos_embed(memory, pos),
tensor_v=memory,)
tensor_v=memory)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.reshape(tgt2, permute_linear)
@ -318,47 +318,38 @@ class TransformerDecoderLayer(nn.Cell):
class TransformerEncoder(nn.Cell):
"""ipt"""
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
def construct(self, src, pos=None):
"""ipt"""
output = src
for layer in self.layers:
output = layer(output, pos=pos)
return output
class TransformerDecoder(nn.Cell):
"""ipt"""
def __init__(self, decoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
def construct(self, tgt, memory, pos=None, query_pos=None):
"""ipt"""
output = tgt
for layer in self.layers:
output = layer(output, memory, pos=pos, query_pos=query_pos)
return output
def _get_clones(module, N):
"""ipt"""
return nn.CellList([copy.deepcopy(module) for i in range(N)])
def _get_clones(module, n):
return nn.CellList([copy.deepcopy(module) for i in range(n)])
class LearnedPositionalEncoding(nn.Cell):
"""ipt"""
def __init__(self, max_position_embeddings, embedding_dim, seq_length):
super(LearnedPositionalEncoding, self).__init__()
self.pe = nn.Embedding(
@ -370,8 +361,7 @@ class LearnedPositionalEncoding(nn.Cell):
self.position_ids = self.reshape(
self.position_ids, (1, self.seq_length))
def construct(self, x, position_ids=None):
"""ipt"""
def construct(self, position_ids=None):
if position_ids is None:
position_ids = self.position_ids[:, : self.seq_length]
@ -381,46 +371,35 @@ class LearnedPositionalEncoding(nn.Cell):
class VisionTransformer(nn.Cell):
"""ipt"""
def __init__(
self,
img_dim,
patch_dim,
num_channels,
embedding_dim,
num_heads,
num_layers,
hidden_dim,
num_queries,
idx,
positional_encoding_type="learned",
dropout_rate=0,
norm=False,
mlp=False,
pos_every=False,
no_pos=False
):
def __init__(self,
img_dim,
patch_dim,
num_channels,
embedding_dim,
num_heads,
num_layers,
hidden_dim,
num_queries,
dropout_rate=0,
norm=False,
mlp=False,
pos_every=False,
no_pos=False,
con_loss=False):
super(VisionTransformer, self).__init__()
assert embedding_dim % num_heads == 0
assert img_dim % patch_dim == 0
self.norm = norm
self.mlp = mlp
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.patch_dim = patch_dim
self.num_channels = num_channels
self.img_dim = img_dim
self.pos_every = pos_every
self.num_patches = int((img_dim // patch_dim) ** 2)
self.seq_length = self.num_patches
self.flatten_dim = patch_dim * patch_dim * num_channels
self.out_dim = patch_dim * patch_dim * num_channels
self.no_pos = no_pos
self.unf = _unfold_(patch_dim)
self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim))
@ -432,8 +411,7 @@ class VisionTransformer(nn.Cell):
nn.Dropout(1. - dropout_rate),
nn.ReLU(),
nn.Dense(hidden_dim, self.out_dim),
nn.Dropout(1. - dropout_rate)
)
nn.Dropout(1. - dropout_rate))
self.query_embed = nn.Embedding(
num_queries, embedding_dim * self.seq_length)
@ -449,55 +427,54 @@ class VisionTransformer(nn.Cell):
self.tile = P.Tile()
self.transpose = P.Transpose()
if not self.no_pos:
self.position_encoding = LearnedPositionalEncoding(
self.seq_length, self.embedding_dim, self.seq_length
)
self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim, self.seq_length)
self.dropout_layer1 = nn.Dropout(1. - dropout_rate)
self.query_idx = idx
self.query_idx_tensor = Tensor(idx, mstype.int32)
def construct(self, x):
self.con_loss = con_loss
def construct(self, x, query_idx_tensor):
"""ipt"""
B, _, _, _ = x.shape
x = self.unf(x)
B, N, _ = x.shape
b, n, _ = x.shape
if self.mlp is not True:
x = self.reshape(x, (B * N, -1))
x = self.reshape(x, (b * n, -1))
x = self.dropout_layer1(self.linear_encoding(x)) + x
x = self.reshape(x, (B, N, -1))
x = self.reshape(x, (b, n, -1))
query_embed = self.tile(
self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)),
(B, 1, 1))
self.reshape(self.query_embed(query_idx_tensor), (1, self.seq_length, self.embedding_dim)), (b, 1, 1))
if not self.no_pos:
pos = self.position_encoding(x)
pos = self.position_encoding()
x = self.encoder(x + pos)
else:
x = self.encoder(x)
x = self.decoder(x, x, query_pos=query_embed)
if self.mlp is not True:
x = self.reshape(x, (B * N, -1))
x = self.reshape(x, (b * n, -1))
x = self.mlp_head(x) + x
x = self.reshape(x, (B, N, -1))
x = self.reshape(x, (b, n, -1))
if self.con_loss:
con_x = x
x = self.fold(x)
return x, con_x
x = self.fold(x)
return x
def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
"""ipt"""
return nn.Conv2d(
in_channels, out_channels, kernel_size, has_bias=has_bias)
return nn.Conv2d(in_channels, out_channels, kernel_size, has_bias=has_bias)
class MeanShift(nn.Conv2d):
"""ipt"""
def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
def __init__(self,
rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040),
rgb_std=(1.0, 1.0, 1.0),
sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
self.reshape = P.Reshape()
self.eye = P.Eye()
@ -512,10 +489,14 @@ class MeanShift(nn.Conv2d):
class ResBlock(nn.Cell):
"""ipt"""
def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.ReLU(), res_scale=1):
def __init__(self,
conv,
n_feats,
kernel_size,
bias=True,
bn=False,
act=nn.ReLU(),
res_scale=1):
super(ResBlock, self).__init__()
m = []
@ -532,35 +513,28 @@ class ResBlock(nn.Cell):
self.mul = P.Mul()
def construct(self, x):
"""ipt"""
res = self.mul(self.body(x), self.res_scale)
res += x
return res
def _pixelsf_(x, scale):
"""ipt"""
N, C, iH, iW = x.shape
oH = iH * scale
oW = iW * scale
oC = C // (scale ** 2)
output = P.Reshape()(x, (N, oC, scale, scale, iH, iW))
output = P.Transpose()(output, (0, 1, 5, 3, 4, 2))
output = P.Reshape()(output, (N, oC, oH, oW))
output = P.Transpose()(output, (0, 1, 3, 2))
n, c, ih, iw = x.shape
oh = ih * scale
ow = iw * scale
oc = c // (scale ** 2)
output = P.Transpose()(x, (0, 2, 1, 3))
output = P.Reshape()(output, (n, ih, oc*scale, scale, iw))
output = P.Transpose()(output, (0, 1, 2, 4, 3))
output = P.Reshape()(output, (n, ih, oc, scale, ow))
output = P.Transpose()(output, (0, 2, 1, 3, 4))
output = P.Reshape()(output, (n, oc, oh, ow))
return output
class SmallUpSampler(nn.Cell):
"""ipt"""
def __init__(self, conv, upsize, n_feats, bn=False, act=False, bias=True):
def __init__(self, conv, upsize, n_feats, bias=True):
super(SmallUpSampler, self).__init__()
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
self.reshape = P.Reshape()
@ -568,7 +542,6 @@ class SmallUpSampler(nn.Cell):
self.pixelsf = _pixelsf_
def construct(self, x):
"""ipt"""
x = self.conv(x)
output = self.pixelsf(x, self.upsize)
return output
@ -576,47 +549,37 @@ class SmallUpSampler(nn.Cell):
class Upsampler(nn.Cell):
"""ipt"""
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
def __init__(self, conv, scale, n_feats, bias=True):
super(Upsampler, self).__init__()
m = []
if (scale & (scale - 1)) == 0:
for _ in range(int(math.log(scale, 2))):
m.append(SmallUpSampler(conv, 2, n_feats, bias=bias))
elif scale == 3:
m.append(SmallUpSampler(conv, 3, n_feats, bias=bias))
self.net = nn.SequentialCell(m)
def construct(self, x):
"""ipt"""
return self.net(x)
class IPT(nn.Cell):
"""ipt"""
def __init__(self, args, conv=default_conv):
super(IPT, self).__init__()
self.dytpe = mstype.float16
self.scale_idx = 0
self.args = args
self.con_loss = args.con_loss
n_feats = args.n_feats
kernel_size = 3
act = nn.ReLU()
self.sub_mean = MeanShift(args.rgb_range)
self.add_mean = MeanShift(args.rgb_range, sign=1)
self.head = nn.CellList([
nn.SequentialCell(
conv(args.n_colors, n_feats, kernel_size),
ResBlock(conv, n_feats, 5, act=act),
ResBlock(conv, n_feats, 5, act=act)
) for _ in args.scale
])
nn.SequentialCell(conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe),
ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe),
ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe)) for _ in range(6)])
self.body = VisionTransformer(img_dim=args.patch_size,
patch_dim=args.patch_dim,
@ -630,36 +593,34 @@ class IPT(nn.Cell):
mlp=args.no_mlp,
pos_every=args.pos_every,
no_pos=args.no_pos,
idx=self.scale_idx)
con_loss=args.con_loss).to_float(self.dytpe)
self.tail = nn.CellList([
nn.SequentialCell(
Upsampler(conv, s, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)
) for s in args.scale
])
nn.SequentialCell(Upsampler(conv, s, n_feats).to_float(self.dytpe),
conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)) \
for s in [2, 3, 4, 1, 1, 1]])
self.reshape = P.Reshape()
self.tile = P.Tile()
self.transpose = P.Transpose()
self.s2t = P.ScalarToTensor()
self.cast = P.Cast()
def construct(self, x):
def construct(self, x, idx):
"""ipt"""
x = self.sub_mean(x)
x = self.head[self.scale_idx](x)
res = self.body(x)
idx_num = idx.shape[0]
x = self.head[idx_num](x)
idx_tensor = self.cast(self.s2t(idx_num), mstype.int32)
if self.con_loss:
res, x_con = self.body(x, idx_tensor)
res += x
x = self.tail[idx_num](x)
return x, x_con
res = self.body(x, idx_tensor)
res += x
x = self.tail[self.scale_idx](res)
x = self.add_mean(x)
x = self.tail[idx_num](res)
return x
def set_scale(self, scale_idx):
"""ipt"""
self.body.query_idx = scale_idx
self.scale_idx = scale_idx
class IPT_post():
"""ipt"""
def __init__(self, model, args):
@ -674,17 +635,13 @@ class IPT_post():
self.cc_2 = P.Concat(axis=2)
self.cc_3 = P.Concat(axis=3)
def set_scale(self, scale_idx):
"""ipt"""
self.body.query_idx = scale_idx
self.scale_idx = scale_idx
def forward(self, x, shave=12, batchsize=64):
def forward(self, x, idx, shave=12, batchsize=64):
"""ipt"""
self.idx = idx
h, w = x.shape[-2:]
padsize = int(self.args.patch_size)
shave = int(self.args.patch_size / 4)
scale = self.args.scale[self.scale_idx]
scale = self.args.scale[0]
h_cut = (h - padsize) % (padsize - shave)
w_cut = (w - padsize) % (padsize - shave)
@ -692,7 +649,7 @@ class IPT_post():
x_unfold = unf_1.compute(x)
x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2)
x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
y_hw_cut = self.model(x_hw_cut)
y_hw_cut = self.model(x_hw_cut, self.idx)
x_h_cut = x[:, :, (h - padsize):, :]
x_w_cut = x[:, :, :, (w - padsize):]
@ -714,10 +671,10 @@ class IPT_post():
for i in range(x_range):
if i == 0:
y_unfold = self.model(
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
else:
y_unfold = self.cc_0((y_unfold, self.model(
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)))
y_unf_shape_0 = y_unfold.shape[0]
fold_1 = \
_stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale),
@ -740,17 +697,18 @@ class IPT_post():
stride=padsize * scale - shave * scale)
y_inter = fold_2.compute(self.transpose(self.reshape(
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), \
int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter))
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, \
int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)]))
concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2))
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
return y
def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
@ -766,11 +724,11 @@ class IPT_post():
for i in range(x_range):
if i == 0:
y_h_cut_unfold = self.model(
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
else:
y_h_cut_unfold = \
self.cc_0((y_h_cut_unfold, self.model(
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)))
y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
fold_1 = \
_stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale),
@ -802,10 +760,11 @@ class IPT_post():
for i in range(x_range):
if i == 0:
y_w_cut_unfold = self.model(
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
else:
y_w_cut_unfold = self.cc_0((y_w_cut_unfold,
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], \
self.idx)))
y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
fold_1 = _stride_fold_(padsize * scale,
output_shape=((h - h_cut) * scale,
@ -827,7 +786,6 @@ class IPT_post():
class _stride_unfold_():
'''stride'''
def __init__(self,
kernel_size,
stride=-1):
@ -874,13 +832,12 @@ class _stride_unfold_():
zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape)
concat4 = np.concatenate((concat3, zeros4), axis=3)
unf_x += concat4
unf_x = Tensor(unf_x, mstype.float32)
unf_x = Tensor(unf_x, mstype.float16)
y = self.unfold(unf_x)
return y
class _stride_fold_():
'''stride'''
def __init__(self,
kernel_size,
output_shape=(-1, -1),
@ -905,7 +862,7 @@ class _stride_fold_():
self.fold = _fold_(self.kernel_size, self.large_shape)
def compute(self, x):
'''stride'''
""" compute"""
NumBlock_x = self.NumBlock_x
NumBlock_y = self.NumBlock_y
large_x = self.fold(x)
@ -917,7 +874,8 @@ class _stride_fold_():
leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1])
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], \
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32)
for i in range(NumBlock_x):
for j in range(NumBlock_y):
fold_i = i * self.stride
@ -938,12 +896,11 @@ class _stride_fold_():
zeros4 = np.zeros(t4.shape)
concat4 = np.concatenate((concat3, zeros4), axis=3)
fold_x += concat4
y = Tensor(fold_x, mstype.float32)
y = Tensor(fold_x, mstype.float16)
return y
class _unfold_(nn.Cell):
"""ipt"""
def __init__(
self, kernel_size, stride=-1):
@ -965,8 +922,10 @@ class _unfold_(nn.Cell):
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4))
output_img = self.reshape(output_img, (N*C, numH, numW, self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, numH * numW, self.kernel_size*self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 1, 3))
output_img = self.reshape(output_img, (N, numH * numW, -1))
return output_img
@ -1002,14 +961,10 @@ class _fold_(nn.Cell):
org_W = self.output_shape[1]
numH = org_H // self.kernel_size[0]
numW = org_W // self.kernel_size[1]
output_img = self.reshape(
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
output_img = self.reshape(x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 2, 3, 1, 4))
output_img = self.reshape(output_img, (N*org_C, self.kernel_size[0], numH, numW, self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
output_img = self.reshape(
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
return output_img

View File

@ -0,0 +1,125 @@
"""loss"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
class SupConLoss(nn.Cell):
"""SupConLoss"""
def __init__(self, temperature=0.07, contrast_mode='all',
base_temperature=0.07):
super(SupConLoss, self).__init__()
self.temperature = temperature
self.contrast_mode = contrast_mode
self.base_temperature = base_temperature
self.normalize = P.L2Normalize(axis=2)
self.eye = P.Eye()
self.unbind = P.Unstack(axis=1)
self.cat = P.Concat(axis=0)
self.matmul = P.MatMul()
self.div = P.Div()
self.transpose = P.Transpose()
self.maxes = P.ArgMaxWithValue(axis=1, keep_dims=True)
self.tile = P.Tile()
self.scatter = P.ScatterNd()
self.oneslike = P.OnesLike()
self.exp = P.Exp()
self.sum = P.ReduceSum(keep_dims=True)
self.log = P.Log()
self.reshape = P.Reshape()
self.mean = P.ReduceMean()
def construct(self, features):
"""SupConLoss"""
features = self.normalize(features)
batch_size = features.shape[0]
mask = self.eye(batch_size, batch_size, mstype.float32)
contrast_count = features.shape[1]
contrast_feature = self.cat(self.unbind(features))
if self.contrast_mode == 'all':
anchor_feature = contrast_feature
anchor_count = contrast_count
else:
anchor_feature = features[:, 0]
anchor_count = 1
anchor_dot_contrast = self.div(self.matmul(anchor_feature, self.transpose(contrast_feature, (1, 0))), \
self.temperature)
_, logits_max = self.maxes(anchor_dot_contrast)
logits = anchor_dot_contrast - logits_max
mask = self.tile(mask, (anchor_count, contrast_count))
logits_mask = 1 - self.eye(mask.shape[0], mask.shape[1], mstype.float32)
mask = mask * logits_mask
exp_logits = self.exp(logits) * logits_mask
log_prob = logits - self.log(self.sum(exp_logits, 1) + 1e-8)
mean_log_prob_pos = self.sum((mask * log_prob), 1) / self.sum(mask, 1)
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = self.mean(self.reshape(loss, (anchor_count, batch_size)))
return loss, anchor_count
class ClipGradients(nn.Cell):
"""
Clip gradients.
Args:
grads (list): List of gradient tuples.
clip_type (Tensor): The way to clip, 'value' or 'norm'.
clip_value (Tensor): Specifies how much to clip.
Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self, grads, clip_type, clip_value):
"""ClipGradients"""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
t = self.cast(t, dt)
new_grads = new_grads + (t,)
return new_grads
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)

17
model_zoo/research/cv/IPT/src/metrics.py Executable file → Normal file
View File

@ -1,4 +1,4 @@
'''metrics'''
"""metrics"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
def quantize(img, rgb_range):
'''metrics'''
"""metrics"""
pixel_range = 255 / rgb_range
img = np.multiply(img, pixel_range)
img = np.clip(img, 0, 255)
@ -26,15 +26,14 @@ def quantize(img, rgb_range):
return img
def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
'''metrics'''
def calc_psnr(sr, hr, scale, rgb_range):
"""metrics"""
hr = np.float32(hr)
sr = np.float32(sr)
diff = (sr - hr) / rgb_range
gray_coeffs = np.array([65.738, 129.057, 25.064]
).reshape((1, 3, 1, 1)) / 256
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
diff = np.multiply(diff, gray_coeffs).sum(1)
if np.size(hr) == 1:
if hr.size == 1:
return 0
if scale != 1:
shave = scale
@ -49,7 +48,7 @@ def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
def rgb2ycbcr(img, y_only=True):
'''metrics'''
"""metrics"""
img.astype(np.float32)
if y_only:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0

View File

@ -1,67 +0,0 @@
'''temp'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
def set_template(args):
'''temp'''
if args.template.find('jpeg') >= 0:
args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg'
args.epochs = 200
args.decay = '100'
if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR'
args.n_resblocks = 32
args.n_feats = 256
args.res_scale = 0.1
if args.template.find('MDSR') >= 0:
args.model = 'MDSR'
args.patch_size = 48
args.epochs = 650
if args.template.find('DDBPN') >= 0:
args.model = 'DDBPN'
args.patch_size = 128
args.scale = '4'
args.data_test = 'Set5'
args.batch_size = 20
args.epochs = 1000
args.decay = '500'
args.gamma = 0.1
args.weight_decay = 1e-4
args.loss = '1*MSE'
if args.template.find('GAN') >= 0:
args.epochs = 200
args.lr = 5e-5
args.decay = '150'
if args.template.find('RCAN') >= 0:
args.model = 'RCAN'
args.n_resgroups = 10
args.n_resblocks = 20
args.n_feats = 64
args.chop = True
if args.template.find('VDSR') >= 0:
args.model = 'VDSR'
args.n_resblocks = 20
args.n_feats = 64
args.patch_size = 41
args.lr = 1e-1

View File

@ -0,0 +1,173 @@
"""utils"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import time
from bisect import bisect_right
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.train.serialization import save_checkpoint
from src.loss import SupConLoss
class MyTrain(nn.Cell):
"""MyTrain"""
def __init__(self, model, criterion, con_loss, use_con=True):
super(MyTrain, self).__init__(auto_prefix=True)
self.use_con = use_con
self.model = model
self.con_loss = con_loss
self.criterion = criterion
self.p = P.Print()
self.cast = P.Cast()
def construct(self, lr, hr, idx):
"""MyTrain"""
if self.use_con:
sr, x_con = self.model(lr, idx)
x_con = self.cast(x_con, mstype.float32)
sr = self.cast(sr, mstype.float32)
loss1 = self.criterion(sr, hr)
loss2 = self.con_loss(x_con)
loss = loss1 + 0.1 * loss2
else:
sr = self.model(lr, idx)
sr = self.cast(sr, mstype.float32)
loss = self.criterion(sr, hr)
return loss
class MyTrainOneStepCell(nn.Cell):
"""MyTrainOneStepCell"""
def __init__(self, network, optimizer, sens=1.0):
super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, True, 8)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
def sub_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] -= red_channel_mean
x[:, 1, :, :] -= green_channel_mean
x[:, 2, :, :] -= blue_channel_mean
return x
def add_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] += red_channel_mean
x[:, 1, :, :] += green_channel_mean
x[:, 2, :, :] += blue_channel_mean
return x
class Trainer():
"""Trainer"""
def __init__(self, args, loader, my_model):
self.args = args
self.scale = args.scale
self.trainloader = loader
self.model = my_model
self.model.set_train()
self.criterion = nn.L1Loss()
self.con_loss = SupConLoss()
self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
self.train_net = MyTrain(self.model, self.criterion, self.con_loss, use_con=args.con_loss)
self.bp = MyTrainOneStepCell(self.train_net, self.optimizer, 1024.0)
def train(self):
"""Trainer"""
losses = 0
for batch_idx, imgs in enumerate(self.trainloader):
lr = imgs["LR"]
hr = imgs["HR"]
lr = Tensor(sub_mean(lr), mstype.float32)
hr = Tensor(sub_mean(hr), mstype.float32)
idx = Tensor(np.ones(imgs["idx"][0]), mstype.int32)
t1 = time.time()
loss = self.bp(lr, hr, idx)
t2 = time.time()
losses += loss.asnumpy()
print('Task: %g, Step: %g, loss: %f, time: %f s' % (idx.shape[0], batch_idx, loss.asnumpy(), t2 - t1),
flush=True)
os.makedirs(self.args.save, exist_ok=True)
if self.args.rank == 0:
save_checkpoint(self.bp, self.args.save + "model_" + str(self.epoch) + '.ckpt')
def update_learning_rate(self, epoch):
"""Update learning rates for all the networks; called at the end of every epoch.
:param epoch: current epoch
:type epoch: int
:param lr: learning rate of cyclegan
:type lr: float
:param niter: number of epochs with the initial learning rate
:type niter: int
:param niter_decay: number of epochs to linearly decay learning rate to zero
:type niter_decay: int
"""
self.epoch = epoch
value = self.args.decay.split('-')
value.sort(key=int)
milestones = list(map(int, value))
print("*********** epoch: {} **********".format(epoch))
lr = self.args.lr * self.args.gamma ** bisect_right(milestones, epoch)
self.adjust_lr('model', self.optimizer, lr)
print("*********************************")
def adjust_lr(self, name, optimizer, lr):
"""Adjust learning rate for the corresponding model.
:param name: name of model
:type name: str
:param optimizer: the optimizer of the corresponding model
:type optimizer: torch.optim
:param lr: learning rate to be adjusted
:type lr: float
"""
lr_param = optimizer.get_lr()
lr_param.assign_value(Tensor(lr, mstype.float32))
print('==> ' + name + ' learning rate: ', lr_param.asnumpy())

View File

@ -0,0 +1,154 @@
"""train"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import math
import mindspore.dataset as ds
from mindspore import Parameter, set_seed, context
from mindspore.context import ParallelMode
from mindspore.common.initializer import initializer, HeUniform, XavierUniform, Uniform, Normal, Zero
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
from src.data.bicubic import bicubic
from src.data.imagenet import ImgData
from src.ipt_model import IPT
from src.utils import Trainer
def _calculate_fan_in_and_fan_out(shape):
"""
calculate fan_in and fan_out
Args:
shape (tuple): input shape.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dimensions = len(shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2:
fan_in = shape[1]
fan_out = shape[0]
else:
num_input_fmaps = shape[1]
num_output_fmaps = shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = shape[2] * shape[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
:param net: network to be initialized
:type net: nn.Module
:param init_type: the name of an initialization method: normal | xavier | kaiming | orthogonal
:type init_type: str
:param init_gain: scaling factor for normal, xavier and orthogonal.
:type init_gain: float
"""
for _, cell in net.cells_and_names():
classname = cell.__class__.__name__
if hasattr(cell, 'in_proj_layer'):
cell.in_proj_layer = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.in_proj_layer.shape,
cell.in_proj_layer.dtype), name=cell.in_proj_layer.name)
if hasattr(cell, 'weight'):
if init_type == 'normal':
cell.weight = Parameter(initializer(Normal(init_gain), cell.weight.shape,
cell.weight.dtype), name=cell.weight.name)
elif init_type == 'xavier':
cell.weight = Parameter(initializer(XavierUniform(init_gain), cell.weight.shape,
cell.weight.dtype), name=cell.weight.name)
elif init_type == "he":
cell.weight = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.weight.shape,
cell.weight.dtype), name=cell.weight.name)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(cell, 'bias') and cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
bound = 1 / math.sqrt(fan_in)
cell.bias = Parameter(initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype),
name=cell.bias.name)
elif classname.find('BatchNorm2d') != -1:
cell.gamma = Parameter(initializer(Normal(1.0), cell.gamma.default_input.shape()), name=cell.gamma.name)
cell.beta = Parameter(initializer(Zero(), cell.beta.default_input.shape()), name=cell.beta.name)
print('initialize network weight with %s' % init_type)
def train_net(distribute, imagenet, epochs):
"""Train net"""
set_seed(1)
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
if imagenet == 1:
train_dataset = ImgData(args)
else:
train_dataset = data.Data(args).loader_train
if distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
print('Rank {}, rank_size {}'.format(rank_id, rank_size))
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
num_shards=rank_size, shard_id=args.rank, shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=rank_size,
shard_id=rank_id, shuffle=True)
else:
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], shuffle=True)
resize_fuc = bicubic()
train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
train_de_dataset = train_de_dataset.batch(args.batch_size,
input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
output_columns=["LR", "HR", "idx", "filename"],
drop_remainder=True, per_batch_map=resize_fuc.forward)
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
net_work = IPT(args)
init_weights(net_work, init_type='he', init_gain=1.0)
print("Init net weight successfully")
if args.pth_path:
param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_work, param_dict)
print("Load net weight successfully")
train_func = Trainer(args, train_loader, net_work)
for epoch in range(0, epochs):
train_func.update_learning_rate(epoch)
train_func.train()
if __name__ == '__main__':
train_net(distribute=args.distribute, imagenet=args.imagenet, epochs=args.epochs)

View File

@ -0,0 +1,95 @@
"""train finetune"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from mindspore import context
from mindspore.context import ParallelMode
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
from src.args import args
from src.data.imagenet import ImgData
from src.data.srdata import SRData
from src.data.div2k import DIV2K
from src.data.bicubic import bicubic
from src.ipt_model import IPT
from src.utils import Trainer
def train_net(distribute, imagenet):
"""Train net with finetune"""
set_seed(1)
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
if imagenet == 1:
train_dataset = ImgData(args)
elif not args.derain:
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
else:
train_dataset = SRData(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
if distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
print('Rank {}, group_size {}'.format(rank_id, rank_size))
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"],
num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"], shuffle=True)
if args.imagenet == 1:
resize_fuc = bicubic()
train_de_dataset = train_de_dataset.batch(
args.batch_size,
input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
per_batch_map=resize_fuc.forward)
else:
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
net_m = IPT(args)
print("Init net weights successfully")
if args.pth_path:
param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_m, param_dict)
print("Load net weight successfully")
train_func = Trainer(args, train_loader, net_m)
for epoch in range(0, args.epochs):
train_func.update_learning_rate(epoch)
train_func.train()
if __name__ == "__main__":
train_net(distribute=args.distribute, imagenet=args.imagenet)