forked from mindspore-Ecosystem/mindspore
!14981 add Advanced EAST into model_zoo
From: @leidaowaijiao Reviewed-by: Signed-off-by:
This commit is contained in:
commit a1e19320bc

@@ -1,42 +1,56 @@
# AdvancedEAST

# Contents

- [AdvancedEAST Description](#advancedeast-description)
- [Environment](#environment)
- [Dependences](#dependences)
- [Project Files](#project-files)
- [Dataset](#dataset)
- [Run The Project](#run-the-project)
    - [Data Preprocess](#data-preprocess)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Performance](#performance)
    - [Training Performance](#training-performance)
    - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)

# [AdvancedEAST Description](#contents)

AdvancedEAST is inspired by EAST: [EAST: An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155v2).
The architecture of AdvancedEAST is shown below.

![AdvancedEast network arch](AdvancedEast.network.png)

This project inherits from [huoyijie/AdvancedEAST](https://github.com/huoyijie/AdvancedEAST) (preprocessing, network architecture, prediction) and [BaoWentz/AdvancedEAST-PyTorch](https://github.com/BaoWentz/AdvancedEAST-PyTorch) (performance).

# [Environment](#contents)

- EulerOS v2r7 x86_64 / Ubuntu 16.04
- Python 3.7.5

# [Dependences](#contents)

- mindspore==1.2.0
- shapely==1.7.1
- numpy==1.19.4
- tqdm==4.36.1

# [Project Files](#contents)

- configuration file:
    cfg.py, controls parameters
- pre-process data:
    preprocess.py, resizes images
- label data:
    label.py, produces label info
- define network:
    model.py and vgg.py
- define loss function:
    losses.py
- execute training:
    advanced_east.py and dataset.py
- predict:
    predict.py and nms.py
- scoring:
    score.py
- logging:
    logger.py

```shell
@@ -44,10 +58,11 @@
└── advanced_east
    ├── README.md
    ├── scripts
    │   ├── run_distribute_train_ascend.sh  # launch ascend distributed training(8 pcs)
    │   ├── run_standalone_train_ascend.sh  # launch ascend standalone training(1 pcs)
    │   ├── run_distribute_train_gpu.sh     # launch gpu distributed training(8 pcs)
    │   ├── run_standalone_train_gpu.sh     # launch gpu standalone training(1 pcs)
    │   └── eval.sh                         # evaluate model(1 pcs)
    ├── src
    │   ├── cfg.py                          # parameter configuration
    │   ├── dataset.py                      # data preprocessing
@@ -58,6 +73,7 @@
    │   ├── predict.py                      # predict boxes
    │   ├── preprocess.py                   # pre-process data
    │   ├── score.py                        # scoring
    │   └── vgg.py                          # vgg model
    ├── export.py                           # export model for inference
    ├── prepare_data.py                     # exec data preprocessing
    ├── eval.py                             # eval net
@@ -65,7 +81,7 @@
    └── train_mindrecord.py                 # train net on user specified mindrecord
```
# [Dataset](#contents)

ICPR MTWI 2018 challenge 2: text detection of network images, [Link](https://tianchi.aliyun.com/competition/entrance/231651/introduction). The dataset can no longer be downloaded from the original webpage;
it is now provided by the author of the original project, [Baiduyun link](https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA), password: ye9y. There are 10000 images and corresponding label
@@ -80,7 +96,7 @@ of 9:1. If you want to use your own dataset, please modify the configuration of

> ```

Some parameters in config.py:

```shell
'validation_split_ratio': 0.1, # ratio of validation dataset
'total_img': 10000, # total number of samples in dataset
'data_dir': './icpr/', # dir of dataset
@@ -94,9 +110,9 @@ Some parameters in config.py:
'train_label_dir_name': 'labels_train/', # dir which stores the preprocessed text vertices.
```
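
These fields live in an `EasyDict` (see src/config.py), so a run script can read or override them in place before training. A minimal sketch, assuming the import path used elsewhere in this commit:

```python
from src.config import config as cfg

cfg.validation_split_ratio = 0.2   # hold out 20% instead of the default 10%
cfg.data_dir = '/data/my_icpr/'    # point at a custom dataset directory
print(cfg.total_img)               # EasyDict fields support attribute access
```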
# [Run The Project](#contents)

## [Data Preprocess](#contents)

Resize all the images to a fixed size and convert the label information (the vertices of the text boxes) into the format used in training and evaluation; the MindRecord files are then generated (a sketch of the flow follows the command below).

@@ -104,23 +120,30 @@

```bash
python prepare_data.py
```
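
A hedged sketch of what prepare_data.py drives, based on the calls shown later in this commit; the module paths of `process_label_size` and `transImage2Mind` are assumptions (only `preprocess_size` is confirmed to live in src/preprocess.py here):

```python
import os

from src.config import config as cfg
from src.preprocess import preprocess_size   # confirmed below in this commit
from src.label import process_label_size     # assumed location
from src.dataset import transImage2Mind      # assumed location

for size in (256, 384, 448):      # the training scales used in prepare_data.py
    preprocess_size(size)         # resize images to size x size
    process_label_size(size)      # regenerate ground-truth geometry at that scale

# pack the resized images and labels into a MindRecord file
transImage2Mind(os.path.join(cfg.data_dir, cfg.mindsrecord_train_file))
```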
## [Training Process](#contents)

Prepare the VGG16 pre-trained model. Due to copyright restrictions, please go to https://github.com/machrisaa/tensorflow-vgg to download the VGG16 pre-trained model and place it in the src folder.
If you already have a VGG16 checkpoint, you can load its parameters as sketched below, which shortens training time noticeably.
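
A minimal sketch of that checkpoint load; it mirrors the GPU branch of `AdvancedEast.__init__` in src/model.py, and `cfg.vgg_weights` defaults to `./src/0-150_5004.ckpt` in src/config.py:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config as cfg
from src.vgg import vgg16

backbone = vgg16()                             # VGG16 feature extractor from src/vgg.py
param_dict = load_checkpoint(cfg.vgg_weights)  # read the pre-trained .ckpt
load_param_into_net(backbone, param_dict)      # copy the parameters into the network
```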
- single Ascend

```bash
python train.py --device_target="Ascend" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
```

- single GPU

```bash
python train.py --device_target="GPU" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
```

- single device with specific size

```bash
python train_mindrecord.py --device_target="Ascend" --is_distributed=0 --device_id=2 --size=256 > output.train.log 2>&1 &
```

- multi Ascends

```bash
# running on distributed environment(8p)
bash scripts/run_distribute_train.sh
```
@@ -129,7 +152,7 @@

The detailed training parameters are in /src/config.py.

- multi GPUs

```bash
# running on distributed environment(8p)
bash scripts/run_distribute_train_gpu.sh
```
@@ -153,13 +176,34 @@ config.py:

```shell
'max_predict_img_size': 448, # max size of the images to predict
'ckpt_save_max': 10, # maximum of ckpt in dir
'saved_model_file_path': './saved_model/', # dir of saved model
'norm': 'BN', # normalization in feature merging branch
```
## [Evaluation Process](#contents)

Evaluation reports loss, precision, recall, F1 score, and the predicted box vertices of an image (src/score.py computes F1 as 2 * precision * recall / (precision + recall)).

The commands below run in the background; view the results in the file output.eval.log.

- loss

```bash
# evaluate loss of the model
bash run_eval.sh 0_8-24_1012.ckpt loss
```

- score

```bash
# evaluate precision, recall and F1 of the model
bash run_eval.sh 0_8-24_1012.ckpt score
```

- prediction

```bash
# get prediction of an image
bash run_eval.sh 0_8-24_1012.ckpt pred ./demo/001.png
```
## Inference Process

@@ -191,32 +235,37 @@ Inference result is saved in current path, you can find result in acc.log file.

## [Training Performance](#contents)

The performance listed below is acquired with the default configurations in /src/config.py.
The model trained on Ascend uses GN (GroupNorm) in the feature-merging branch; the model trained on GPU uses BN (BatchNorm).

| Parameters          | single Ascend                                   | 8 GPUs                                          |
| ------------------- | ----------------------------------------------- | ----------------------------------------------- |
| Model Version       | AdvancedEAST                                    | AdvancedEAST                                    |
| Resources           | Ascend 910                                      | Tesla V100S-PCIE 32G                            |
| MindSpore Version   | 1.1                                             | 1.1                                             |
| Dataset             | MTWI-2018                                       | MTWI-2018                                       |
| Training Parameters | epoch=18, batch_size=8, lr=1e-3                 | epoch=84, batch_size=8, lr=1e-3                 |
| Optimizer           | AdamWeightDecay                                 | AdamWeightDecay                                 |
| Loss Function       | QuadLoss                                        | QuadLoss                                        |
| Outputs             | matrix with size of 3x64x64, 3x96x96, 3x112x112 | matrix with size of 3x64x64, 3x96x96, 3x112x112 |
| Loss                | 0.1                                             | 0.1                                             |
| Total Time          | 28 mins, 60 mins, 90 mins                       | 4.9 mins, 10.3 mins, 14.5 mins                  |
| Checkpoints         | 173MB (.ckpt file)                              | 173MB (.ckpt file)                              |

## [Evaluation Performance](#contents)

On the default configuration:

| Parameters        | single Ascend              | 8 GPUs                     |
| ----------------- | -------------------------- | -------------------------- |
| Model Version     | AdvancedEAST               | AdvancedEAST               |
| Resources         | Ascend 910                 | Tesla V100S-PCIE 32G       |
| MindSpore Version | 1.1                        | 1.1                        |
| Dataset           | 1000 images                | 1000 images                |
| batch_size        | 8                          | 8                          |
| Outputs           | precision, recall, F score | precision, recall, F score |
| performance       | 94.35, 55.45, 66.31        | 92.53, 55.49, 66.01        |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -21,25 +21,26 @@ import os

import numpy as np
from PIL import Image
from tqdm import tqdm
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint

from src.logger import get_logger
from src.predict import predict
from src.score import eval_pre_rec_f1

from src.config import config as cfg
from src.dataset import load_adEAST_dataset
from src.model import get_AdvancedEast_net, AdvancedEast
from src.preprocess import resize_image


set_seed(1)


def parse_args():
    """parameters"""
    parser = argparse.ArgumentParser('adveast evaling')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],

@@ -78,11 +79,11 @@ def parse_args(cloud_args=None):

    return args_opt


def eval_loss(eval_arg):
    """get network and init"""
    loss_net, train_net = get_AdvancedEast_net(eval_arg)
    print(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt))
    load_param_into_net(train_net, load_checkpoint(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt)))
    train_net.set_train(False)
    loss = 0
    idx = 0

@@ -92,36 +93,36 @@ def eval_loss(arg):

    print(loss / idx)


def eval_score(eval_arg):
    """get network and init"""
    net = AdvancedEast(eval_arg)
    load_param_into_net(net, load_checkpoint(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt)))
    net.set_train(False)
    obj = eval_pre_rec_f1()
    with open(os.path.join(eval_arg.data_dir, eval_arg.val_fname), 'r') as f_val:
        f_list = f_val.readlines()

    img_h, img_w = eval_arg.max_predict_img_size, eval_arg.max_predict_img_size
    x = np.zeros((eval_arg.batch_size, 3, img_h, img_w), dtype=np.float32)
    batch_list = np.arange(0, len(f_list), eval_arg.batch_size)
    for idx in tqdm(batch_list):
        gt_list = []
        for i in range(idx, min(idx + eval_arg.batch_size, len(f_list))):
            item = f_list[i]
            img_filename = str(item).strip().split(',')[0][:-4]
            img_path = os.path.join(eval_arg.train_image_dir_name, img_filename) + '.jpg'

            img = Image.open(img_path)
            img = img.resize((img_w, img_h), Image.NEAREST).convert('RGB')
            img = np.asarray(img)
            img = img / 1.
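            # Added note: the constants below are the mean RGB pixel values the
            # original VGG16 release was trained with; subtracting them matches
            # the backbone's expected input preprocessing.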
            mean = np.array((123.68, 116.779, 103.939)).reshape([1, 1, 3])
            img = (img - mean).astype(np.float32)
            img = img.transpose((2, 0, 1))
            x[i - idx] = img
            gt_list.append(np.load(os.path.join(eval_arg.train_label_dir_name, img_filename) + '.npy'))
        if idx + eval_arg.batch_size >= len(f_list):
            x = x[:len(f_list) - idx]
        y = net(Tensor(x))
        obj.add(y, gt_list)

@@ -129,12 +130,12 @@ def eval_score(arg):

    print(obj.val())


def pred(eval_arg):
    """pred"""
    img_path = eval_arg.path
    net = AdvancedEast(eval_arg)
    load_param_into_net(net, load_checkpoint(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt)))
    predict(net, img_path, eval_arg.pixel_threshold)


if __name__ == '__main__':

@@ -163,7 +164,6 @@ if __name__ == '__main__':

    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    dataset, batch_num = load_adEAST_dataset(os.path.join(args.data_dir,
                                                          args.mindsrecord_test_file),
                                             batch_size=args.batch_size,
@@ -30,7 +30,7 @@ parser.add_argument('--width', type=int, default=256, help='input width')

parser.add_argument('--height', type=int, default=256, help='input height')
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
                    choices=["Ascend", "GPU"], help="device target(default: Ascend)")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@@ -49,14 +49,6 @@ if __name__ == '__main__':

    process_label_size(384)
    preprocess_size(448)
    process_label_size(448)
    mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file)
    transImage2Mind(mindrecord_filename)
    mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_test_file)
@@ -0,0 +1,3 @@

tqdm
shapely
opencv
@@ -0,0 +1,26 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_gpu.sh"
echo "for example: bash run_distribute_train_gpu.sh"
echo "=============================================================================================================="

mpirun -n 8 --output-filename log_output --merge-stderr-to-stdout --allow-run-as-root \
    python train.py \
    --device_target="GPU" \
    --is_distributed=1 > output.train.log 2>&1 &
@@ -0,0 +1,40 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh ckpt_filename eval_method pred_image_name(if eval_method is pred)"
echo "for example: bash run_eval.sh 0_8-24_1012.ckpt pred ./demo/001.png"
echo "for example: bash run_eval.sh 0_8-24_1012.ckpt score"
echo "=============================================================================================================="

CKPT=$1
METHOD=$2

if [ $# == 3 ] && [ $METHOD == "pred" ]
then
    PATH1=$3
    python eval.py \
        --device_target="Ascend" \
        --ckpt=$CKPT \
        --method=$METHOD \
        --path=$PATH1 > output.eval.log 2>&1 &
else
    python eval.py \
        --device_target="Ascend" \
        --ckpt=$CKPT \
        --method=$METHOD > output.eval.log 2>&1 &
fi
@@ -0,0 +1,40 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh ckpt_filename eval_method pred_image_name(if eval_method is pred)"
echo "for example: bash run_eval.sh 0_8-24_1012.ckpt pred ./demo/001.png"
echo "for example: bash run_eval.sh 0_8-24_1012.ckpt score"
echo "=============================================================================================================="

CKPT=$1
METHOD=$2

if [ $# == 3 ] && [ $METHOD == "pred" ]
then
    PATH1=$3
    python eval.py \
        --device_target="GPU" \
        --ckpt=$CKPT \
        --method=$METHOD \
        --path=$PATH1 > output.eval.log 2>&1 &
else
    python eval.py \
        --device_target="GPU" \
        --ckpt=$CKPT \
        --method=$METHOD > output.eval.log 2>&1 &
fi
@@ -1,3 +1,4 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");

@@ -12,3 +13,12 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train_gpu.sh"
echo "for example: bash run_standalone_train_gpu.sh"
echo "=============================================================================================================="

python train.py \
    --device_target="GPU" > output.train.log 2>&1 &
@@ -18,10 +18,11 @@ configs

from easydict import EasyDict

config = EasyDict({
    'initial_epoch': 2,
    'epoch_num': 6,
    'learning_rate_ascend': 1e-4,
    'learning_rate_gpu': 1e-3,
    'decay_ascend': 3e-5,
    'decay_gpu': 5e-4,
    'epsilon': 1e-4,
    'batch_size': 2,
    'ckpt_interval': 2,

@@ -36,7 +37,6 @@ config = EasyDict({

    'validation_split_ratio': 0.1,
    'total_img': 10000,
    'data_dir': './icpr/',
    'train_fname': 'train.txt',
    'train_fname_var': 'train_',
    'val_fname': 'val.txt',

@@ -71,5 +71,33 @@ config = EasyDict({

    'gen_origin_img': True,
    'draw_gt_quad': False,
    'draw_act_quad': False,
    'vgg_npy': './vgg16.npy',
    'vgg_weights': './src/0-150_5004.ckpt',
    'ds_sink_mode': False
})

cifar_cfg = EasyDict({
    "num_classes": 10,
    "lr": 0.01,
    "lr_init": 0.01,
    "lr_max": 0.1,
    "lr_epochs": '30,60,90,120',
    "lr_scheduler": "step",
    "warmup_epochs": 5,
    "batch_size": 64,
    "max_epoch": 70,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "loss_scale": 1.0,
    "label_smooth": 0,
    "label_smooth_factor": 0,
    "buffer_size": 10,
    "image_size": '224,224',
    "pad_mode": 'same',
    "padding": 0,
    "has_bias": False,
    "batch_norm": True,
    "keep_checkpoint_max": 10,
    "initialize_mode": "XavierUniform",
    "has_dropout": False
})
@@ -86,14 +86,14 @@ def shrink(xy_list, ratio=cfg.shrink_ratio):

    new_xy_list = np.copy(temp_new_xy_list)
    shrink_edge(temp_new_xy_list, new_xy_list, short_edge + 2, r, theta, ratio)
    return temp_new_xy_list, new_xy_list, long_edge  # shrunk long edges, shrunk short edges, long-edge index


def shrink_edge(xy_list, new_xy_list, edge, r, theta, ratio=cfg.shrink_ratio):
    """shrink edge"""
    if ratio == 0.0:
        return
    start_point = edge  # index of the edge's start point (0 or 1)
    end_point = (edge + 1) % 4
    long_start_sign_x = np.sign(
        xy_list[end_point, 0] - xy_list[start_point, 0])
@@ -17,89 +17,168 @@ dataset processing.

"""
import mindspore
import mindspore.nn as nn
from mindspore.nn.optim import AdamWeightDecay
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import ResizeNearestNeighbor
from mindspore import Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import initializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import numpy as np

from src.vgg import Vgg
from src.config import config as cfg

vgg_cfg = {
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
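# Added note: in these layer configs the integers are conv output channel
# counts and 'M' marks a 2x2 max-pool stage boundary (see _make_layer in
# src/vgg.py, which splits the stages on 'M').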

def vgg16(num_classes=1000, args=None, phase="train"):
    """
    Get Vgg16 neural network with batch normalization.

    Args:
        num_classes (int): Class numbers. Default: 1000.
        args (namespace): param for net init.
        phase (str): train or test mode.

    Returns:
        Cell, cell instance of Vgg16 neural network with batch normalization.

    Examples:
        >>> vgg16(num_classes=1000, args=args)
    """
    if args is None:
        from src.config import cifar_cfg
        args = cifar_cfg
    net = Vgg(vgg_cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
    return net

class AdvancedEast(nn.Cell):
    """
    AdvancedEAST model.

    Args:
        args: parsed arguments; args.device_target selects the GPU (BatchNorm)
            or Ascend (GroupNorm) variant of the network.
    """

    def __init__(self, args):
        super(AdvancedEast, self).__init__()
        self.device_target = args.device_target
        if self.device_target == 'GPU':
            self.vgg16 = vgg16()
            param_dict = load_checkpoint(cfg.vgg_weights)
            load_param_into_net(self.vgg16, param_dict)

            self.bn1 = nn.BatchNorm2d(1024, momentum=0.99, eps=1e-3)
            self.conv1 = nn.Conv2d(1024, 128, 1, weight_init='XavierUniform', has_bias=True)
            self.relu1 = nn.ReLU()

            self.bn2 = nn.BatchNorm2d(128, momentum=0.99, eps=1e-3)
            self.conv2 = nn.Conv2d(128, 128, 3, padding=1, pad_mode='pad', weight_init='XavierUniform')
            self.relu2 = nn.ReLU()

            self.bn3 = nn.BatchNorm2d(384, momentum=0.99, eps=1e-3)
            self.conv3 = nn.Conv2d(384, 64, 1, weight_init='XavierUniform', has_bias=True)
            self.relu3 = nn.ReLU()

            self.bn4 = nn.BatchNorm2d(64, momentum=0.99, eps=1e-3)
            self.conv4 = nn.Conv2d(64, 64, 3, padding=1, pad_mode='pad', weight_init='XavierUniform')
            self.relu4 = nn.ReLU()

            self.bn5 = nn.BatchNorm2d(192, momentum=0.99, eps=1e-3)
            self.conv5 = nn.Conv2d(192, 32, 1, weight_init='XavierUniform', has_bias=True)
            self.relu5 = nn.ReLU()

            self.bn6 = nn.BatchNorm2d(32, momentum=0.99, eps=1e-3)
            self.conv6 = nn.Conv2d(32, 32, 3, padding=1, pad_mode='pad', weight_init='XavierUniform', has_bias=True)
            self.relu6 = nn.ReLU()

            self.bn7 = nn.BatchNorm2d(32, momentum=0.99, eps=1e-3)
            self.conv7 = nn.Conv2d(32, 32, 3, padding=1, pad_mode='pad', weight_init='XavierUniform', has_bias=True)
            self.relu7 = nn.ReLU()

            self.cat = P.Concat(axis=1)

            self.conv8 = nn.Conv2d(32, 1, 1, weight_init='XavierUniform', has_bias=True)
            self.conv9 = nn.Conv2d(32, 2, 1, weight_init='XavierUniform', has_bias=True)
            self.conv10 = nn.Conv2d(32, 4, 1, weight_init='XavierUniform', has_bias=True)
        else:
            vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()

            def get_var(name, idx):
                value = vgg_dict[name][idx]
                if idx == 0:
                    value = np.transpose(value, [3, 2, 0, 1])
                var = Tensor(value)
                return var

            def get_conv_var(name):
                filters = get_var(name, 0)
                biases = get_var(name, 1)
                return filters, biases

            class VGG_Conv(nn.Cell):
                """
                VGG16 network definition.
                """

                def __init__(self, name):
                    super(VGG_Conv, self).__init__()
                    filters, conv_biases = get_conv_var(name)
                    out_channels, in_channels, filter_size, _ = filters.shape
                    self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1)
                    self.bias_add = P.BiasAdd()
                    self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]),
                                            name='weight')
                    self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias')
                    self.relu = P.ReLU()
                    self.gn = nn.GroupNorm(32, out_channels)

                def construct(self, x):
                    output = self.conv2d(x, self.weight)
                    output = self.bias_add(output, self.bias)
                    output = self.gn(output)
                    output = self.relu(output)
                    return output

            self.conv1_1 = VGG_Conv('conv1_1')
            self.conv1_2 = VGG_Conv('conv1_2')
            self.pool1 = nn.MaxPool2d(2, 2)
            self.conv2_1 = VGG_Conv('conv2_1')
            self.conv2_2 = VGG_Conv('conv2_2')
            self.pool2 = nn.MaxPool2d(2, 2)
            self.conv3_1 = VGG_Conv('conv3_1')
            self.conv3_2 = VGG_Conv('conv3_2')
            self.conv3_3 = VGG_Conv('conv3_3')
            self.pool3 = nn.MaxPool2d(2, 2)
            self.conv4_1 = VGG_Conv('conv4_1')
            self.conv4_2 = VGG_Conv('conv4_2')
            self.conv4_3 = VGG_Conv('conv4_3')
            self.pool4 = nn.MaxPool2d(2, 2)
            self.conv5_1 = VGG_Conv('conv5_1')
            self.conv5_2 = VGG_Conv('conv5_2')
            self.conv5_3 = VGG_Conv('conv5_3')
            self.pool5 = nn.MaxPool2d(2, 2)
            self.merging1 = self.merging(i=2)
            self.merging2 = self.merging(i=3)
            self.merging3 = self.merging(i=4)
            self.last_bn = nn.GroupNorm(16, 32)
            self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, has_bias=True, weight_init='XavierUniform')
            self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, has_bias=True,
                                               weight_init='XavierUniform')
            self.side_v_angle_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, has_bias=True,
                                               weight_init='XavierUniform')
            self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, has_bias=True,
                                               weight_init='XavierUniform')
            self.op_concat = P.Concat(axis=1)
            self.relu = P.ReLU()
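    # Added note: in both branches the detector ends in three 1x1 conv heads
    # producing 1 (inside score) + 2 (side-vertex code) + 4 (side-vertex
    # geometry) channels, concatenated into the 7-channel map that
    # src/score.py consumes as (N, 7, 64, 64) for 256x256 inputs.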
    def merging(self, i=2):
        """

@@ -119,141 +198,190 @@ class AdvancedEast(nn.Cell):

    def construct(self, x):
        """
        forward func
        """
        if self.device_target == 'GPU':
            l2, l3, l4, l5 = self.vgg16(x)
            h = l5

            _, _, h_, w_ = P.Shape()(h)
            g = ResizeNearestNeighbor((h_ * 2, w_ * 2))(h)
            c = self.cat((g, l4))

            c = self.bn1(c)
            c = self.conv1(c)
            c = self.relu1(c)

            h = self.bn2(c)
            h = self.conv2(h)
            h = self.relu2(h)

            _, _, h_, w_ = P.Shape()(h)
            g = ResizeNearestNeighbor((h_ * 2, w_ * 2))(h)
            c = self.cat((g, l3))

            c = self.bn3(c)
            c = self.conv3(c)
            c = self.relu3(c)

            h = self.bn4(c)
            h = self.conv4(h)  # bs 64 w/8 h/8
            h = self.relu4(h)

            _, _, h_, w_ = P.Shape()(h)
            g = ResizeNearestNeighbor((h_ * 2, w_ * 2))(h)
            c = self.cat((g, l2))

            c = self.bn5(c)
            c = self.conv5(c)
            c = self.relu5(c)

            h = self.bn6(c)
            h = self.conv6(h)  # bs 32 w/4 h/4
            h = self.relu6(h)

            g = self.bn7(h)
            g = self.conv7(g)  # bs 32 w/4 h/4
            g = self.relu7(g)

            # get output
            inside_score = self.conv8(g)
            side_v_code = self.conv9(g)
            side_v_coord = self.conv10(g)
            east_detect = self.cat((inside_score, side_v_code, side_v_coord))
        else:
            f4 = self.conv1_1(x)
            f4 = self.conv1_2(f4)
            f4 = self.pool1(f4)
            f4 = self.conv2_1(f4)
            f4 = self.conv2_2(f4)
            f4 = self.pool2(f4)
            f3 = self.conv3_1(f4)
            f3 = self.conv3_2(f3)
            f3 = self.conv3_3(f3)
            f3 = self.pool3(f3)
            f2 = self.conv4_1(f3)
            f2 = self.conv4_2(f2)
            f2 = self.conv4_3(f2)
            f2 = self.pool4(f2)
            f1 = self.conv5_1(f2)
            f1 = self.conv5_2(f1)
            f1 = self.conv5_3(f1)
            f1 = self.pool5(f1)
            h1 = f1
            _, _, h_, w_ = P.Shape()(h1)
            H1 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h1)
            concat1 = self.op_concat((H1, f2))
            h2 = self.merging1(concat1)
            _, _, h_, w_ = P.Shape()(h2)
            H2 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h2)
            concat2 = self.op_concat((H2, f3))
            h3 = self.merging2(concat2)
            _, _, h_, w_ = P.Shape()(h3)
            H3 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h3)
            concat3 = self.op_concat((H3, f4))
            h4 = self.merging3(concat3)
            before_output = self.relu(self.last_bn(self.conv_last(h4)))
            inside_score = self.inside_score_conv(before_output)
            side_v_angle = self.side_v_angle_conv(before_output)
            side_v_coord = self.side_v_coord_conv(before_output)
            east_detect = self.op_concat((inside_score, side_v_coord, side_v_angle))

        return east_detect

class EastWithLossCell(nn.Cell):
    """
    Loss cell: wraps the detector and computes the QuadLoss objective.
    """

    def __init__(self, network):
        super(EastWithLossCell, self).__init__()
        self.East_network = network
        self.cat = P.Concat(axis=1)
    def dice_loss(self, gt_score, pred_score):
        """dice_loss1"""
        inter = P.ReduceSum()(gt_score * pred_score)
        union = P.ReduceSum()(gt_score) + P.ReduceSum()(pred_score) + 1e-5
        return 1. - (2 * (inter / union))

    def dice_loss2(self, gt_score, pred_score, mask):
        """dice_loss2"""
        inter = P.ReduceSum()(gt_score * pred_score * mask)
        union = P.ReduceSum()(gt_score * mask) + P.ReduceSum()(pred_score * mask) + 1e-5
        return 1. - (2 * (inter / union))

    def quad_loss(self, y_true, y_pred,
                  lambda_inside_score_loss=0.2,
                  lambda_side_vertex_code_loss=0.1,
                  lambda_side_vertex_coord_loss=1.0,
                  epsilon=1e-4):
        """quad loss"""
        y_true = P.Transpose()(y_true, (0, 2, 3, 1))
        y_pred = P.Transpose()(y_pred, (0, 2, 3, 1))
        logits = y_pred[:, :, :, :1]
        labels = y_true[:, :, :, :1]
        predicts = P.Sigmoid()(logits)
        inside_score_loss = self.dice_loss(labels, predicts)
        inside_score_loss = inside_score_loss * lambda_inside_score_loss
        # loss for side_vertex_code
        vertex_logitsp = P.Sigmoid()(y_pred[:, :, :, 1:2])
        vertex_labelsp = y_true[:, :, :, 1:2]
        vertex_logitsn = P.Sigmoid()(y_pred[:, :, :, 2:3])
        vertex_labelsn = y_true[:, :, :, 2:3]
        labels2 = y_true[:, :, :, 1:2]
        side_vertex_code_lossp = self.dice_loss2(vertex_labelsp, vertex_logitsp, labels)
        side_vertex_code_lossn = self.dice_loss2(vertex_labelsn, vertex_logitsn, labels2)
        side_vertex_code_loss = (side_vertex_code_lossp + side_vertex_code_lossn) * lambda_side_vertex_code_loss
        # loss for side_vertex_coord delta
        g_hat = y_pred[:, :, :, 3:]  # N*W*H*8
        g_true = y_true[:, :, :, 3:]
        vertex_weights = P.Cast()(P.Equal()(y_true[:, :, :, 1], 1), mindspore.float32)

        pixel_wise_smooth_l1norm = self.smooth_l1_loss(g_hat, g_true, vertex_weights)
        side_vertex_coord_loss = P.ReduceSum()(pixel_wise_smooth_l1norm) / (
            P.ReduceSum()(vertex_weights) + epsilon)
        side_vertex_coord_loss = side_vertex_coord_loss * lambda_side_vertex_coord_loss
        return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss

    def smooth_l1_loss(self, prediction_tensor, target_tensor, weights):
        """smooth l1 loss"""
        n_q = P.Reshape()(self.quad_norm(target_tensor), weights.shape)
        diff = P.SmoothL1Loss()(prediction_tensor, target_tensor)
        pixel_wise_smooth_l1norm = P.ReduceSum()(diff, -1) / n_q * weights
        return pixel_wise_smooth_l1norm

    def quad_norm(self, g_true, epsilon=1e-4):
        """quad norm"""
        shape = g_true.shape
        delta_xy_matrix = P.Reshape()(g_true, (shape[0] * shape[1] * shape[2], 2, 2))
        diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :]
        square = diff * diff
        distance = P.Sqrt()(P.ReduceSum()(square, -1))
        distance = distance * 4.0
        distance = distance + epsilon
        return P.Reshape()(distance, (shape[0], shape[1], shape[2]))

    def construct(self, image, label):
        y_pred = self.East_network(image)
        loss = self.quad_loss(label, y_pred)
        return loss
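    # Added note: with the default lambda weights above, the total objective is
    #   L = 0.2 * L_dice(inside score)
    #     + 0.1 * (L_dice(pos vertex code) + L_dice(neg vertex code))
    #     + 1.0 * L_coord,
    # where L_dice(g, p) = 1 - 2*sum(g*p) / (sum(g) + sum(p) + 1e-5) and
    # L_coord is the vertex-weighted smooth-L1 term normalized by quad_norm().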

class TrainStepWrap(nn.Cell):
    """
    Train-step wrapper: computes gradients of the loss network and applies the optimizer.
    """

    def __init__(self, network):
        super(TrainStepWrap, self).__init__()
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = AdamWeightDecay(self.weights, learning_rate=cfg.learning_rate,
                                         eps=1e-7, weight_decay=cfg.decay)
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = 1.0

    def construct(self, image, label):
        weights = self.weights

@@ -263,11 +391,11 @@ class TrainStepWrap(nn.Cell):

        return F.depend(loss, self.optimizer(grads))


def get_AdvancedEast_net(args):
    """
    Get the loss network and train network of the AdvancedEAST model.
    """
    AdvancedEast_net = AdvancedEast(args)
    loss_net = EastWithLossCell(AdvancedEast_net)
    train_net = TrainStepWrap(loss_net)
    return loss_net, train_net
@@ -72,7 +72,6 @@ def reorder_vertexes(xy_list):

    b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0]
    second_v, fourth_v = 0, 0
    for index, i in zip(others, range(len(others))):
        delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid)
        if delta_y > 0:
            second_v = index

@@ -140,7 +139,7 @@ def preprocess():

        if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
                os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
            continue

        d_wight, d_height = cfg.max_train_img_size, cfg.max_train_img_size
        scale_ratio_w = d_wight / im.width
        scale_ratio_h = d_height / im.height

@@ -242,7 +241,7 @@ def preprocess_size(width=256):

        if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \
                os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)):
            continue

        d_wight, d_height = width, width
        scale_ratio_w = d_wight / im.width
        scale_ratio_h = d_height / im.height

@@ -303,7 +302,6 @@ def preprocess_size(width=256):

    train_label_list = os.listdir(train_label_dir)
    print('found %d train labels.' % len(train_label_list))
    val_count = int(cfg.validation_split_ratio * len(train_val_set))
    with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'a') as f_val:
        f_val.writelines(train_val_set[:val_count])
@@ -21,9 +21,6 @@ from shapely.geometry import Polygon

from src.config import config as cfg
from src.nms import nms


class Averager():
    """Compute average for torch.Tensor, used for loss average."""

@@ -99,12 +96,10 @@ class eval_pre_rec_f1():

            return 0, 0, 0
        quad_flag = np.zeros(num_quads)
        gt_flag = np.zeros(num_gts)
        quad_scores_no_zero = np.array(quad_scores_no_zero)
        scores_idx = np.argsort(quad_scores_no_zero)[::-1]
        for i in range(num_quads):
            idx = scores_idx[i]
            geo = quad_after_nms_no_zero[idx]
            for j in range(num_gts):
                if gt_flag[j] == 0:

@@ -122,14 +117,12 @@ class eval_pre_rec_f1():

            f1_score = 0
        else:
            f1_score = 2 * pre * rec / (pre + rec)
        return pre, rec, f1_score

    def add(self, out, gt_xy_list):
        """`add`"""
        self.img_num += len(gt_xy_list)
        ys = out.asnumpy()  # (N, 7, 64, 64)
        if ys.shape[1] == 7:
            ys = ys.transpose((0, 2, 3, 1))  # NCHW->NHWC
        for y, gt_xy in zip(ys, gt_xy_list):

@@ -137,7 +130,6 @@ class eval_pre_rec_f1():

            cond = np.greater_equal(y[:, :, 0], self.pixel_threshold)
            activation_pixels = np.where(cond)
            quad_scores, quad_after_nms = nms(y, activation_pixels)
            lens = len(quad_after_nms)
            if lens == 0 or (sum(sum(quad_scores)) == 0):
                if not cfg.quiet:

@@ -145,7 +137,6 @@ class eval_pre_rec_f1():

                    continue
            else:
                pre, rec, f1_score = self.eval_one(quad_scores, quad_after_nms, gt_xy)
                self.pre += pre
                self.rec += rec
                self.f1_score += f1_score
@ -0,0 +1,146 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Image classifiation.
|
||||
"""
|
||||
import math
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common import initializer as init
|
||||
from mindspore.common.initializer import initializer, HeNormal
|
||||
|
||||
from src.config import config
|
||||
npy = config['vgg_npy']
|
||||
|
||||
def _make_layer(base, args, batch_norm):
|
||||
"""Make stage network of VGG."""
|
||||
layers = []
|
||||
in_channels = 3
|
||||
layer = []
|
||||
for v in base:
|
||||
if v == 'M':
|
||||
layer += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
layers.append(layer)
|
||||
layer = []
|
||||
else:
|
||||
weight = 'ones'
|
||||
if args.initialize_mode == "XavierUniform":
|
||||
weight_shape = (v, in_channels, 3, 3)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
|
||||
|
||||
conv2d = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=v,
|
||||
kernel_size=3,
|
||||
padding=args.padding,
|
||||
pad_mode=args.pad_mode,
|
||||
has_bias=args.has_bias,
|
||||
weight_init=weight)
|
||||
if batch_norm:
|
||||
layer += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
|
||||
else:
|
||||
layer += [conv2d, nn.ReLU()]
|
||||
in_channels = v
|
||||
layer1 = nn.SequentialCell(layers[0])
|
||||
layer2 = nn.SequentialCell(layers[1])
|
||||
layer3 = nn.SequentialCell(layers[2])
|
||||
layer4 = nn.SequentialCell(layers[3])
|
||||
layer5 = nn.SequentialCell(layers[4])
|
||||
|
||||
return layer1, layer2, layer3, layer4, layer5
|
||||
|
||||
|
||||
class Vgg(nn.Cell):
|
||||
"""
|
||||
VGG network definition.
|
||||
|
||||
Args:
|
||||
base (list): Configuration for different layers, mainly the channel number of Conv layer.
|
||||
num_classes (int): Class numbers. Default: 1000.
|
||||
batch_norm (bool): Whether to do the batchnorm. Default: False.
|
||||
batch_size (int): Batch size. Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, infer output tensor.
|
||||
|
||||
Examples:
|
||||
>>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
>>> num_classes=1000, batch_norm=False, batch_size=1)
|
||||
"""
|
||||
|
||||
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
|
||||
super(Vgg, self).__init__()
|
||||
_ = batch_size
|
||||
self.layer1, self.layer2, self.layer3, self.layer4, self.layer5 = _make_layer(base, args, batch_norm=batch_norm)
|
||||
if args.initialize_mode == "KaimingNormal":
|
||||
#default_recurisive_init(self)
|
||||
self.custom_init_weight()
|
||||
|
||||
def construct(self, x):
|
||||
l1 = self.layer1(x)
|
||||
l2 = self.layer2(l1)
|
||||
l3 = self.layer3(l2)
|
||||
l4 = self.layer4(l3)
|
||||
l5 = self.layer5(l4)
|
||||
return l2, l3, l4, l5
|
||||
|
||||
def custom_init_weight(self):
|
||||
"""
|
||||
Init the weight of Conv2d and Dense in the net.
|
||||
"""
|
||||
for _, cell in self.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.set_data(init.initializer(
|
||||
HeNormal(negative_slope=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
|
||||
cell.weight.shape, cell.weight.dtype))
|
||||
if cell.bias is not None:
|
||||
cell.bias.set_data(init.initializer(
|
||||
'zeros', cell.bias.shape, cell.bias.dtype))
|
||||
elif isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(init.initializer(
|
||||
init.Normal(0.01), cell.weight.shape, cell.weight.dtype))
|
||||
if cell.bias is not None:
|
||||
cell.bias.set_data(init.initializer(
|
||||
'zeros', cell.bias.shape, cell.bias.dtype))
|
||||
|
||||
|
||||
cfg = {
|
||||
'11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
|
||||
}
|
||||
|
||||
|
||||
def vgg16(num_classes=1000, args=None, phase="train"):
|
||||
"""
|
||||
Get Vgg16 neural network with batch normalization.
|
||||
|
||||
Args:
|
||||
num_classes (int): Class numbers. Default: 1000.
|
||||
args(namespace): param for net init.
|
||||
phase(str): train or test mode.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of Vgg16 neural network with batch normalization.
|
||||
|
||||
Examples:
|
||||
>>> vgg16(num_classes=1000, args=args)
|
||||
"""
|
||||
|
||||
if args is None:
|
||||
from src.config import cifar_cfg
|
||||
args = cifar_cfg
|
||||
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
|
||||
return net
|
|
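Taken together, `_make_layer` splits the VGG16 configuration at each `'M'` into five pooled stages, and `construct` exposes the last four feature maps for AdvancedEAST's merging branch. A quick smoke test might look like the sketch below; the `Namespace` fields simply mirror the attributes the constructor reads, and the import path is an assumption based on the project layout, not repository code:

```python
# Minimal sketch (assumed import path and args fields), not part of the repository.
import numpy as np
from argparse import Namespace

import mindspore as ms
from mindspore import context

from src.VGG import vgg16  # assumed module path, see "Project Files" above

context.set_context(mode=context.PYNATIVE_MODE)
args = Namespace(initialize_mode='XavierUniform',  # or 'KaimingNormal'
                 padding=1, pad_mode='pad', has_bias=True, batch_norm=True)
net = vgg16(args=args)
x = ms.Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))
l2, l3, l4, l5 = net(x)  # feature maps at strides 4, 8, 16 and 32
```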
@@ -18,9 +18,11 @@
 import argparse
 import datetime
 import os
 import time
+import ast

 from mindspore import context, Model
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_param_into_net, load_checkpoint
@@ -34,15 +36,16 @@ from src.config import config as cfg
 set_seed(1)


-def parse_args(cloud_args=None):
+def parse_args():
     """parameters"""
     parser = argparse.ArgumentParser('mindspore adveast training')
     parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
-    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
+    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend.')

     # network related
-    parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
+    parser.add_argument('--pre_trained', default=False, type=ast.literal_eval,
+                        help='model_path, local pretrained model to load')

     # logging and checkpoint related
     parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
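The switch from `type=bool` to `ast.literal_eval` fixes a classic argparse pitfall: `bool()` maps every non-empty string to `True`, so `--pre_trained False` would still have enabled pretrained loading. A short illustration:

```python
import ast

bool('False')               # True: any non-empty string is truthy
ast.literal_eval('False')   # False: the Python literal is actually parsed
ast.literal_eval('True')    # True
```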
@@ -55,27 +58,21 @@ def parse_args(cloud_args=None):
     parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
     args_opt = parser.parse_args()

     args_opt.initial_epoch = cfg.initial_epoch
     args_opt.epoch_num = cfg.epoch_num
-    args_opt.learning_rate = cfg.learning_rate
-    args_opt.decay = cfg.decay
     args_opt.batch_size = cfg.batch_size
     args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
     args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
     args_opt.ckpt_save_max = cfg.ckpt_save_max
     args_opt.data_dir = cfg.data_dir
     args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
     args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
     args_opt.train_image_dir_name = cfg.train_image_dir_name
     args_opt.train_label_dir_name = cfg.train_label_dir_name
     args_opt.results_dir = cfg.results_dir
     args_opt.last_model_name = cfg.last_model_name
     args_opt.saved_model_file_path = cfg.saved_model_file_path
     args_opt.ds_sink_mode = cfg.ds_sink_mode
     return args_opt


 if __name__ == '__main__':
     args = parse_args()

     device_num = int(os.environ.get("DEVICE_NUM", 1))
     context.set_context(mode=context.GRAPH_MODE)
     workers = 32
@@ -86,6 +83,7 @@ if __name__ == '__main__':
     elif args.device_target == "GPU":
         context.set_context(device_target=args.device_target)
         init()
+        args.rank = get_rank()
         args.group_size = get_group_size()
         device_num = args.group_size
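The added `get_rank()` call gives each GPU process its own rank, which the checkpoint prefixes further down depend on. For context, this is the usual MindSpore data-parallel wiring these calls belong to; the `set_auto_parallel_context` line is the standard pattern and an assumption here, not something shown in this hunk:

```python
# Sketch of the distributed-GPU setup these calls are part of (assumed wiring).
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode

context.set_context(device_target='GPU')
init()                          # initialize the NCCL communication backend
rank = get_rank()               # this process's index within the group
group_size = get_group_size()   # total number of participating devices
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  device_num=group_size)
```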
@@ -114,7 +112,7 @@ if __name__ == '__main__':
         args.rank_save_ckpt_flag = 1

     # get network and init
-    loss_net, train_net = get_AdvancedEast_net()
+    loss_net, train_net = get_AdvancedEast_net(args)
     loss_net.add_flags_recursive(fp32=True)
     train_net.set_train(False)
     # pre_trained
@@ -136,53 +134,81 @@ if __name__ == '__main__':
     train_dataset448, batch_num448 = load_adEAST_dataset(mindrecordfile448, batch_size=2,
                                                          device_num=device_num, rank_id=args.rank, is_training=True,
                                                          num_parallel_workers=workers)

-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
     start = time.time()
+    learning_rate = cfg.learning_rate_ascend if args.device_target == 'Ascend' else cfg.learning_rate_gpu
+    decay = cfg.decay_ascend if args.device_target == 'Ascend' else cfg.decay_gpu
+    # train model using the images resized to 256
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate
+                                          , eps=1e-7, weight_decay=decay)
     model = Model(train_net)
     time_cb = TimeMonitor(data_size=batch_num256)
     loss_cb = LossMonitor(per_print_times=batch_num256)
-    callbacks = [time_cb, loss_cb]
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = args.saved_model_file_path
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='Epoch_A{}'.format(args.rank))
-        callbacks.append(ckpt_cb)
-    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset256, callbacks=callbacks, dataset_sink_mode=False)
-    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                      , eps=1e-7, weight_decay=cfg.decay)
+    callbacks = []
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
+                                   keep_checkpoint_max=args.ckpt_save_max)
+    save_ckpt_path = args.saved_model_file_path
+    ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                              directory=save_ckpt_path,
+                              prefix='Epoch_A{}'.format(args.rank))
+    if args.is_distributed & args.is_save_on_master:
+        if args.rank == 0:
+            callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset256,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    else:
+        callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset256,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
     print(time.time() - start)
+    # train model using the images resized to 384
+    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate
+                                      , eps=1e-7, weight_decay=decay)
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate
+                                          , eps=1e-7, weight_decay=decay)
     model = Model(train_net)

     time_cb = TimeMonitor(data_size=batch_num384)
     loss_cb = LossMonitor(per_print_times=batch_num384)
-    callbacks = [time_cb, loss_cb]
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = args.saved_model_file_path
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='Epoch_B{}'.format(args.rank))
-        callbacks.append(ckpt_cb)
-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
+    callbacks = []
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
+                                   keep_checkpoint_max=args.ckpt_save_max)
+    ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                              directory=save_ckpt_path,
+                              prefix='Epoch_B{}'.format(args.rank))
+    if args.is_distributed & args.is_save_on_master:
+        if args.rank == 0:
+            callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset384,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    else:
+        callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset384,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
     print(time.time() - start)

+    # train model using the images resized to 448
+    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate
+                                      , eps=1e-7, weight_decay=decay)
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate
+                                          , eps=1e-7, weight_decay=decay)
     model = Model(train_net)
-    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset384, callbacks=callbacks, dataset_sink_mode=False)
-    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                      , eps=1e-7, weight_decay=cfg.decay)

     time_cb = TimeMonitor(data_size=batch_num448)
     loss_cb = LossMonitor(per_print_times=batch_num448)
-    callbacks = [time_cb, loss_cb]
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = args.saved_model_file_path
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='Epoch_C{}'.format(args.rank))
-        callbacks.append(ckpt_cb)
-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
-    model = Model(train_net)
-    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset448, callbacks=callbacks, dataset_sink_mode=False)
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
+                                   keep_checkpoint_max=args.ckpt_save_max)
+    callbacks = []
+    ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                              directory=save_ckpt_path,
+                              prefix='Epoch_C{}'.format(args.rank))
+    if args.is_distributed & args.is_save_on_master:
+        if args.rank == 0:
+            callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset448,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    else:
+        callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset448,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)

     print(time.time() - start)
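The three training stages above repeat one pattern at increasing input scales (256, 384, 448) with distinct checkpoint prefixes. The following is a condensed sketch of that pattern, written as a hypothetical refactor that reuses the names defined in the script; it is not the code under review:

```python
# Hypothetical condensation of the three multi-scale training stages above.
stages = [('Epoch_A', train_dataset256, batch_num256),
          ('Epoch_B', train_dataset384, batch_num384),
          ('Epoch_C', train_dataset448, batch_num448)]
for prefix, dataset, batch_num in stages:
    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
                                          eps=1e-7, weight_decay=decay)
    model = Model(train_net)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num,
                                   keep_checkpoint_max=args.ckpt_save_max)
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=args.saved_model_file_path,
                              prefix='{}{}'.format(prefix, args.rank))
    callbacks = []
    # When saving is restricted to the master, only rank 0 logs and checkpoints.
    if not (args.is_distributed and args.is_save_on_master) or args.rank == 0:
        callbacks = [TimeMonitor(data_size=batch_num),
                     LossMonitor(per_print_times=batch_num),
                     ckpt_cb]
    model.train(args.epoch_num, train_dataset=dataset,
                callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
```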
@@ -18,6 +18,7 @@
 import argparse
 import datetime
 import os
+import ast

 from mindspore import context, Model
 from mindspore.common import set_seed

@@ -40,12 +41,13 @@ def parse_args(cloud_args=None):
     parser = argparse.ArgumentParser('mindspore adveast training')
     parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
-    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
+    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend.')

     # network related
-    parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
-    parser.add_argument('--data_path', default='/disk1/ade/icpr/advanced-east-val_256.mindrecord', type=str)
-    parser.add_argument('--pre_trained_ckpt', default='/disk1/adeast/scripts/1.ckpt', type=str)
+    parser.add_argument('--pre_trained', default=False, type=ast.literal_eval,
+                        help='model_path, local pretrained model to load')
+    parser.add_argument('--data_path', type=str)
+    parser.add_argument('--pre_trained_ckpt', type=str)

     # logging and checkpoint related
     parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')

@@ -58,19 +60,12 @@ def parse_args(cloud_args=None):
     parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
     args_opt = parser.parse_args()

     args_opt.initial_epoch = cfg.initial_epoch
     args_opt.epoch_num = cfg.epoch_num
-    args_opt.learning_rate = cfg.learning_rate
-    args_opt.decay = cfg.decay
     args_opt.batch_size = cfg.batch_size
     args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
     args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
     args_opt.ckpt_save_max = cfg.ckpt_save_max
     args_opt.data_dir = cfg.data_dir
     args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
     args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
     args_opt.train_image_dir_name = cfg.train_image_dir_name
     args_opt.train_label_dir_name = cfg.train_label_dir_name
     args_opt.results_dir = cfg.results_dir
     args_opt.last_model_name = cfg.last_model_name
     args_opt.saved_model_file_path = cfg.saved_model_file_path

@@ -123,16 +118,18 @@ if __name__ == '__main__':
         args.rank_save_ckpt_flag = 1

     # get network and init
-    loss_net, train_net = get_AdvancedEast_net()
+    loss_net, train_net = get_AdvancedEast_net(args)
     loss_net.add_flags_recursive(fp32=True)
     train_net.set_train(True)
     # pre_trained
     if args.pre_trained:
         load_param_into_net(train_net, load_checkpoint(args.pre_trained_ckpt))
     # define callbacks

-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
+    learning_rate = cfg.learning_rate_ascend if args.device_target == 'Ascend' else cfg.learning_rate_gpu
+    decay = cfg.decay_ascend if args.device_target == 'Ascend' else cfg.decay_gpu
+
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate
+                                          , eps=1e-7, weight_decay=decay)
     model = Model(train_net)
     time_cb = TimeMonitor(data_size=batch_num)
     loss_cb = LossMonitor(per_print_times=batch_num)
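When `--pre_trained=True`, the script restores weights through the two serialization helpers imported at the top of the file. A minimal resume sketch, assuming `train_net` is in scope and using a placeholder checkpoint path (MindSpore's ModelCheckpoint writes files named `<prefix>-<epoch>_<step>.ckpt`):

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Placeholder path: a file produced by the Epoch_A ModelCheckpoint callback above.
param_dict = load_checkpoint('outputs/Epoch_A0-9_1125.ckpt')
load_param_into_net(train_net, param_dict)
```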