diff --git a/model_zoo/research/cv/advanced_east/README.md b/model_zoo/research/cv/advanced_east/README.md
index ad41900bb77..6967b3fb56b 100644
--- a/model_zoo/research/cv/advanced_east/README.md
+++ b/model_zoo/research/cv/advanced_east/README.md
@@ -1,42 +1,56 @@
-# AdvancedEAST
+# Contents
 
-## introduction
+- [AdvancedEAST Description](#advancedeast-description)
+- [Environment](#environment)
+- [Dependencies](#dependencies)
+- [Project Files](#project-files)
+- [Dataset](#dataset)
+- [Run the Project](#run-the-project)
+    - [Data Preprocess](#data-preprocess)
+    - [Training Process](#training-process)
+    - [Evaluation Process](#evaluation-process)
+- [Performance](#performance)
+    - [Training Performance](#training-performance)
+    - [Evaluation Performance](#evaluation-performance)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [AdvancedEAST Description](#contents)
 
 AdvancedEAST is inspired by EAST [EAST:An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155v2).
-The architecture of AdvancedEAST is showed below![AdvancedEast network arch](AdvancedEast.network.png).
+The architecture of AdvancedEAST is shown below. ![AdvancedEast network arch](AdvancedEast.network.png)
 This project is inherited by [huoyijie/AdvancedEAST](https://github.com/huoyijie/AdvancedEAST)(preprocess, network architecture, predict) and [BaoWentz/AdvancedEAST-PyTorch](https://github.com/BaoWentz/AdvancedEAST-PyTorch)(performance).
 
-## Environment
+# [Environment](#contents)
 
-* euleros v2r7 x86_64
-* python 3.7.5
+- EulerOS V2R7 x86_64 / Ubuntu 16.04
+- python 3.7.5
 
-## Dependences
+# [Dependencies](#contents)
 
-* mindspore==1.1
-* shapely==1.7.1
-* numpy==1.19.4
-* tqdm==4.36.1
+- mindspore==1.2.0
+- shapely==1.7.1
+- numpy==1.19.4
+- tqdm==4.36.1
 
-## Project files
+# [Project Files](#contents)
 
-* configuration of file
+- configuration file
     cfg.py, control parameters
-* pre-process data:
+- pre-process data:
     preprocess.py, resize image
-* label data:
+- label data:
     label.py,produce label info
-* define network
+- define network
     model.py and VGG.py
-* define loss function
+- define loss function
     losses.py
-* execute training
+- execute training
     advanced_east.py and dataset.py
-* predict
+- predict
     predict.py and nms.py
-* scoring
+- scoring
     score.py
-* logging
+- logging
     logger.py
 
 ```shell
@@ -44,10 +58,11 @@ This project is inherited by [huoyijie/AdvancedEAST](https://github.com/huoyijie
 .
 └──advanced_east
     ├── README.md
     ├── scripts
-        ├── run_distribute_train_gpu.sh        # launch ascend distributed training(8 pcs)
+        ├── run_distribute_train_ascend.sh     # launch ascend distributed training(8 pcs)
         ├── run_standalone_train_ascend.sh     # launch ascend standalone training(1 pcs)
         ├── run_distribute_train_gpu.sh        # launch gpu distributed training(8 pcs)
         └── run_standalone_train_gpu.sh        # launch gpu standalone training(1 pcs)
+        ├── run_eval_ascend.sh                 # launch ascend evaluation(1 pcs)
+        └── run_eval_gpu.sh                    # launch gpu evaluation(1 pcs)
     ├── src
         ├── cfg.py                             # parameter configuration
         ├── dataset.py                         # data preprocessing
@@ -58,6 +73,7 @@ This project is inherited by [huoyijie/AdvancedEAST](https://github.com/huoyijie
         ├── predict.py                         # predict boxes
         ├── preprocess.py                      # pre-process data
         └── score.py                           # scoring
+        └── vgg.py                             # vgg16 network definition
     ├── export.py                              # export model for inference
     ├── prepare_data.py                        # exec data preprocessing
     ├── eval.py                                # eval net
@@ -65,7 +81,7 @@ This project is inherited by [huoyijie/AdvancedEAST](https://github.com/huoyijie
     └── train_mindrecord.py                    # train net on user specified mindrecord
 ```
 
-## dataset
+# [Dataset](#contents)
 
 ICPR MTWI 2018 challenge 2:Text detection of network
 image,[Link](https://tianchi.aliyun.com/competition/entrance/231651/introduction).
 It is not available to download dataset on the origin webpage, the dataset is now provided by the author of the original project,[Baiduyun link](https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA), password: ye9y.
 There are 10000 images and corresponding label
@@ -80,7 +96,7 @@ of 9:1. If you want to use your own dataset, please modify the configuration of
 > ```
 
 Some parameters in config.py:
 
-```python
+```shell
 'validation_split_ratio': 0.1, # ratio of validation dataset
 'total_img': 10000, # total number of samples in dataset
 'data_dir': './icpr/', # dir of dataset
@@ -94,9 +110,9 @@ Some parameters in config.py:
 'train_label_dir_name': 'labels_train/', # dir which stores the preprocessed text verteices.
 ```
 
-## Run the project
+# [Run the Project](#contents)
 
-### Data preprocess
+## [Data Preprocess](#contents)
 
 Resize all the images to fixed size, and convert the label information(the vertex of text box) into the format used in training and evaluation, then the Mindsrecord files are generated.
 
@@ -104,23 +120,30 @@ Resize all the images to fixed size, and convert the label information(the verte
-python preparedata.py
+python prepare_data.py
 ```
 
-### Training
+## [Training Process](#contents)
 
 Prepare the VGG16 pre-training model. Due to copyright restrictions, please go to https://github.com/machrisaa/tensorflow-vgg to download the VGG16 pre-training model and place it in the src folder.
+If you already have a VGG16 checkpoint, its parameters can be loaded into the backbone, which shortens training noticeably (a minimal loading sketch is given at the end of this README).
 
-single 1p)
+- single Ascend
 
 ```bash
 python train.py --device_target="Ascend" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
 ```
 
-single 1p)specific size
+- single GPU
+
+```bash
+python train.py --device_target="GPU" --is_distributed=0 --device_id=0 > output.train.log 2>&1 &
+```
+
+- single device with specific size
 
 ```bash
 python train_mindrecord.py --device_target="Ascend" --is_distributed=0 --device_id=2 --size=256 > output.train.log 2>&1 &
 ```
 
-multi Ascends
+- multiple Ascend devices
 
 ```bash
 # running on distributed environment(8p)
@@ -129,7 +152,7 @@ bash scripts/run_distribute_train.sh
 
 The detailed training parameters are in /src/config.py。
 
-multi GPUs
+- multiple GPUs
 
 ```bash
 # running on distributed environment(8p)
@@ -153,13 +176,34 @@ config.py:
 'max_predict_img_size': 448, # max size of the images to predict
 'ckpt_save_max': 10, # maximum of ckpt in dir
 'saved_model_file_path': './saved_model/', # dir of saved model
+'norm': 'BN', # normalization in feature merging branch
 ```
 
-## Evaluation
+## [Evaluation Process](#contents)
 
-### Evaluate
+The evaluation script runs in the background; you can view the results in the file output.eval.log.
+You can obtain the loss, the precision/recall/F1 score, and the predicted box vertices of an image.
 
-The above python command will run in the background, you can view the results through the file output.eval.log. You will get the accuracy as following:
+- loss
+
+```bash
+# evaluate the loss of the model (use run_eval_gpu.sh on GPU)
+bash scripts/run_eval_ascend.sh 0_8-24_1012.ckpt loss
+```
+
+- score
+
+```bash
+# evaluate the precision, recall and F1 score of the model
+bash scripts/run_eval_ascend.sh 0_8-24_1012.ckpt score
+```
+
+- prediction
+
+```bash
+# get the prediction of an image
+bash scripts/run_eval_ascend.sh 0_8-24_1012.ckpt pred ./demo/001.png
+```
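+
+For reference, the pred method of eval.py boils down to the minimal Python sketch below (the argument namespace and checkpoint name are illustrative; eval.py assembles them from src/config.py and the command line, and on Ascend the vgg16.npy file from the Training Process section must also be present):
+
+```python
+# Minimal sketch of what `--method=pred` does (see eval.py).
+import os
+from types import SimpleNamespace
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.config import config as cfg
+from src.model import AdvancedEast
+from src.predict import predict
+
+args = SimpleNamespace(device_target='Ascend')  # AdvancedEast only reads device_target at construction
+net = AdvancedEast(args)
+load_param_into_net(net, load_checkpoint(os.path.join(cfg.saved_model_file_path, '0_8-24_1012.ckpt')))
+net.set_train(False)
+predict(net, './demo/001.png', cfg.pixel_threshold)  # saves the predicted box vertices
+```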
 ## Inference Process
 
@@ -191,32 +235,37 @@ Inference result is saved in current path, you can find result in acc.log file.
-## performance
+# [Performance](#contents)
 
-### Training performance
+## [Training Performance](#contents)
 
-The performance listed below are acquired with the default configurations in /src/config.py
-| Parameters                 | single Ascend                                  |
-| -------------------------- | ---------------------------------------------- |
-| Model Version              | AdvancedEAST                                   |
-| Resources                  | Ascend 910                                     |
-| MindSpore Version          | 1.1                                            |
-| Dataset                    | MTWI-2018                                      |
-| Training Parameters        | epoch=18, batch_size = 8, lr=1e-3              |
-| Optimizer                  | AdamWeightDecay                                |
-| Loss Function              | QuadLoss                                       |
-| Outputs                    | matrix with size of 3x64x64,3x96x96,3x112x112  |
-| Loss                       | 0.1                                            |
-| Total Time                 | 28mins, 60mins, 90mins                         |
-| Checkpoints                | 173MB(.ckpt file)                              |
+The performance listed below is acquired with the default configuration in /src/config.py.
+The model trained on Ascend uses GroupNorm (GN) in the feature merging branch, while the model trained on GPU uses BatchNorm (BN).
+
+| Parameters          | single Ascend                                 | 8 GPUs                                        |
+| ------------------- | --------------------------------------------- | --------------------------------------------- |
+| Model Version       | AdvancedEAST                                  | AdvancedEAST                                  |
+| Resources           | Ascend 910                                    | Tesla V100S-PCIE 32G                          |
+| MindSpore Version   | 1.1                                           | 1.1                                           |
+| Dataset             | MTWI-2018                                     | MTWI-2018                                     |
+| Training Parameters | epoch=18, batch_size = 8, lr=1e-3             | epoch=84, batch_size = 8, lr=1e-3             |
+| Optimizer           | AdamWeightDecay                               | AdamWeightDecay                               |
+| Loss Function       | QuadLoss                                      | QuadLoss                                      |
+| Outputs             | matrix with size of 3x64x64,3x96x96,3x112x112 | matrix with size of 3x64x64,3x96x96,3x112x112 |
+| Loss                | 0.1                                           | 0.1                                           |
+| Total Time          | 28 mins, 60 mins, 90 mins                     | 4.9 mins, 10.3 mins, 14.5 mins                |
+| Checkpoints         | 173MB(.ckpt file)                             | 173MB(.ckpt file)                             |
 
-### Evaluation Performance
+## [Evaluation Performance](#contents)
 
-On the default
+The evaluation performance with the default configuration is as follows.
 
-| Parameters          | single Ascend               |
-| ------------------- | --------------------------- |
-| Model Version       | AdvancedEAST                |
-| Resources           | Ascend 910                  |
-| MindSpore Version   | 1.1                         |
-| Dataset             | 1000 images                 |
-| batch_size          | 8                           |
-| Outputs             | precision, recall, F score  |
-| performance         | 94.35, 55.45, 66.31         |
\ No newline at end of file
+| Parameters          | single Ascend              | 8 GPUs                     |
+| ------------------- | -------------------------- | -------------------------- |
+| Model Version       | AdvancedEAST               | AdvancedEAST               |
+| Resources           | Ascend 910                 | Tesla V100S-PCIE 32G       |
+| MindSpore Version   | 1.1                        | 1.1                        |
+| Dataset             | 1000 images                | 1000 images                |
+| batch_size          | 8                          | 8                          |
+| Outputs             | precision, recall, F score | precision, recall, F score |
+| performance         | 94.35, 55.45, 66.31        | 92.53, 55.49, 66.01        |
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
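+
+As referenced in the [Training Process](#training-process) section, the minimal sketch below shows one way to load a pretrained VGG16 checkpoint into the backbone before training. It mirrors the GPU flow of src/model.py; the path './src/0-150_5004.ckpt' is only the 'vgg_weights' default from src/config.py, so replace it with your own checkpoint.
+
+```python
+# Minimal sketch: load pretrained VGG16 weights into the backbone (GPU flow of src/model.py).
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.vgg import vgg16
+
+backbone = vgg16()                                     # VGG16 feature extractor from src/vgg.py
+param_dict = load_checkpoint('./src/0-150_5004.ckpt')  # cfg.vgg_weights default (assumed path)
+load_param_into_net(backbone, param_dict)              # copy matching parameters into the net
+```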
\ No newline at end of file diff --git a/model_zoo/research/cv/advanced_east/eval.py b/model_zoo/research/cv/advanced_east/eval.py index 1be5e06e489..bcaaef51922 100644 --- a/model_zoo/research/cv/advanced_east/eval.py +++ b/model_zoo/research/cv/advanced_east/eval.py @@ -21,25 +21,26 @@ import os import numpy as np from PIL import Image +from tqdm import tqdm from mindspore import Tensor from mindspore import context from mindspore.common import set_seed from mindspore.communication.management import init, get_rank, get_group_size from mindspore.context import ParallelMode from mindspore.train.serialization import load_param_into_net, load_checkpoint + from src.logger import get_logger from src.predict import predict from src.score import eval_pre_rec_f1 - from src.config import config as cfg from src.dataset import load_adEAST_dataset from src.model import get_AdvancedEast_net, AdvancedEast -from src.preprocess import resize_image + set_seed(1) -def parse_args(cloud_args=None): +def parse_args(): """parameters""" parser = argparse.ArgumentParser('adveast evaling') parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], @@ -78,11 +79,11 @@ def parse_args(cloud_args=None): return args_opt -def eval_loss(arg): +def eval_loss(eval_arg): """get network and init""" - loss_net, train_net = get_AdvancedEast_net() - print(os.path.join(arg.saved_model_file_path, arg.ckpt)) - load_param_into_net(train_net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt))) + loss_net, train_net = get_AdvancedEast_net(eval_arg) + print(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt)) + load_param_into_net(train_net, load_checkpoint(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt))) train_net.set_train(False) loss = 0 idx = 0 @@ -92,36 +93,36 @@ def eval_loss(arg): print(loss / idx) -def eval_score(arg): +def eval_score(eval_arg): """get network and init""" - net = AdvancedEast() - load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt))) + net = AdvancedEast(eval_arg) + load_param_into_net(net, load_checkpoint(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt))) net.set_train(False) obj = eval_pre_rec_f1() - with open(os.path.join(arg.data_dir, arg.val_fname), 'r') as f_val: + with open(os.path.join(eval_arg.data_dir, eval_arg.val_fname), 'r') as f_val: f_list = f_val.readlines() - img_h, img_w = arg.max_predict_img_size, arg.max_predict_img_size - x = np.zeros((arg.batch_size, 3, img_h, img_w), dtype=np.float32) - batch_list = np.arange(0, len(f_list), arg.batch_size) - for idx in batch_list: + img_h, img_w = eval_arg.max_predict_img_size, eval_arg.max_predict_img_size + x = np.zeros((eval_arg.batch_size, 3, img_h, img_w), dtype=np.float32) + batch_list = np.arange(0, len(f_list), eval_arg.batch_size) + for idx in tqdm(batch_list): gt_list = [] - for i in range(idx, min(idx + arg.batch_size, len(f_list))): + for i in range(idx, min(idx + eval_arg.batch_size, len(f_list))): item = f_list[i] img_filename = str(item).strip().split(',')[0][:-4] - img_path = os.path.join(arg.train_image_dir_name, img_filename) + '.jpg' + img_path = os.path.join(eval_arg.train_image_dir_name, img_filename) + '.jpg' + img = Image.open(img_path) - d_wight, d_height = resize_image(img, arg.max_predict_img_size) - img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB') + img = img.resize((img_w, img_h), Image.NEAREST).convert('RGB') img = np.asarray(img) img = img / 1. 
             mean = np.array((123.68, 116.779, 103.939)).reshape([1, 1, 3])
             img = ((img - mean)).astype(np.float32)
             img = img.transpose((2, 0, 1))
             x[i - idx] = img
-            # predict(east, img_path, threshold, cfg.pixel_threshold)
-            gt_list.append(np.load(os.path.join(arg.train_label_dir_name, img_filename) + '.npy'))
-        if idx + arg.batch_size >= len(f_list):
+
+            gt_list.append(np.load(os.path.join(eval_arg.train_label_dir_name, img_filename) + '.npy'))
+        if idx + eval_arg.batch_size >= len(f_list):
             x = x[:len(f_list) - idx]
         y = net(Tensor(x))
         obj.add(y, gt_list)
@@ -129,12 +130,12 @@ def eval_score(arg):
     print(obj.val())
 
 
-def pred(arg):
+def pred(eval_arg):
     """pred"""
-    img_path = arg.path
-    net = AdvancedEast()
-    load_param_into_net(net, load_checkpoint(os.path.join(arg.saved_model_file_path, arg.ckpt)))
-    predict(net, img_path, arg.pixel_threshold)
+    img_path = eval_arg.path
+    net = AdvancedEast(eval_arg)
+    load_param_into_net(net, load_checkpoint(os.path.join(eval_arg.saved_model_file_path, eval_arg.ckpt)))
+    predict(net, img_path, eval_arg.pixel_threshold)
 
 
 if __name__ == '__main__':
@@ -163,7 +164,6 @@ if __name__ == '__main__':
     args.outputs_dir = os.path.join(args.ckpt_path,
                                     datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
     args.logger = get_logger(args.outputs_dir, args.rank)
-    print(os.path.join(args.data_dir, args.mindsrecord_test_file))
 
     dataset, batch_num = load_adEAST_dataset(os.path.join(args.data_dir, args.mindsrecord_test_file),
                                              batch_size=args.batch_size,
diff --git a/model_zoo/research/cv/advanced_east/export.py b/model_zoo/research/cv/advanced_east/export.py
index 697b93fcdac..a1e23a2d43c 100644
--- a/model_zoo/research/cv/advanced_east/export.py
+++ b/model_zoo/research/cv/advanced_east/export.py
@@ -30,7 +30,7 @@ parser.add_argument('--width', type=int, default=256, help='input width')
 parser.add_argument('--height', type=int, default=256, help='input height')
 parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
 parser.add_argument("--device_target", type=str, default="Ascend",
-                    choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
+                    choices=["Ascend", "GPU"], help="device target(default: Ascend)")
 args = parser.parse_args()
 
 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
diff --git a/model_zoo/research/cv/advanced_east/prepare_data.py b/model_zoo/research/cv/advanced_east/prepare_data.py
index fc7008d6a3b..c6afeb90756 100644
--- a/model_zoo/research/cv/advanced_east/prepare_data.py
+++ b/model_zoo/research/cv/advanced_east/prepare_data.py
@@ -49,14 +49,6 @@ if __name__ == '__main__':
     process_label_size(384)
     preprocess_size(448)
     process_label_size(448)
-    # preprocess_size(512)
-    # process_label_size(512)
-    # preprocess_size(640)
-    # process_label_size(640)
-    # preprocess_size(736)
-    # process_label_size(736)
-    # preprocess_size(768)
-    # process_label_size(768)
     mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_train_file)
     transImage2Mind(mindrecord_filename)
     mindrecord_filename = os.path.join(cfg.data_dir, cfg.mindsrecord_test_file)
diff --git a/model_zoo/research/cv/advanced_east/requirements.txt b/model_zoo/research/cv/advanced_east/requirements.txt
new file mode 100644
index 00000000000..f2584917a0c
--- /dev/null
+++ b/model_zoo/research/cv/advanced_east/requirements.txt
@@ -0,0 +1,3 @@
+tqdm
+shapely
+opencv-python
\ No newline at end of file
diff --git a/model_zoo/research/cv/advanced_east/scripts/run_distribute_train.sh
b/model_zoo/research/cv/advanced_east/scripts/run_distribute_train_ascend.sh similarity index 100% rename from model_zoo/research/cv/advanced_east/scripts/run_distribute_train.sh rename to model_zoo/research/cv/advanced_east/scripts/run_distribute_train_ascend.sh diff --git a/model_zoo/research/cv/advanced_east/scripts/run_distribute_train_gpu.sh b/model_zoo/research/cv/advanced_east/scripts/run_distribute_train_gpu.sh new file mode 100644 index 00000000000..e91f510a5e1 --- /dev/null +++ b/model_zoo/research/cv/advanced_east/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_distribute_train_gpu.sh" +echo "for example: bash run_distribute_train_gpu.sh" +echo "==============================================================================================================" + + mpirun -n 8 --output-filename log_output --merge-stderr-to-stdout --allow-run-as-root\ + python train.py \ + --device_target="GPU" \ + --is_distributed=1 > output.train.log 2>&1 & diff --git a/model_zoo/research/cv/advanced_east/scripts/run_eval_ascend.sh b/model_zoo/research/cv/advanced_east/scripts/run_eval_ascend.sh new file mode 100644 index 00000000000..418bec40382 --- /dev/null +++ b/model_zoo/research/cv/advanced_east/scripts/run_eval_ascend.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval_ascend.sh ckpt_filename eval_method pred_image_name(if eval_method is pred)"
+echo "for example: bash run_eval_ascend.sh 0_8-24_1012.ckpt pred ./demo/001.png"
+echo "for example: bash run_eval_ascend.sh 0_8-24_1012.ckpt score"
+echo "=============================================================================================================="
+
+CKPT=$1
+METHOD=$2
+
+if [ $# == 3 ] && [ "$METHOD" == "pred" ]
+then
+    PATH1=$3
+    python eval.py \
+    --device_target="Ascend" \
+    --ckpt=$CKPT \
+    --method=$METHOD \
+    --path=$PATH1 > output.eval.log 2>&1 &
+else
+    python eval.py \
+    --device_target="Ascend" \
+    --ckpt=$CKPT \
+    --method=$METHOD > output.eval.log 2>&1 &
+fi
diff --git a/model_zoo/research/cv/advanced_east/scripts/run_eval_gpu.sh b/model_zoo/research/cv/advanced_east/scripts/run_eval_gpu.sh
new file mode 100644
index 00000000000..62a8052eff5
--- /dev/null
+++ b/model_zoo/research/cv/advanced_east/scripts/run_eval_gpu.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval_gpu.sh ckpt_filename eval_method pred_image_name(if eval_method is pred)"
+echo "for example: bash run_eval_gpu.sh 0_8-24_1012.ckpt pred ./demo/001.png"
+echo "for example: bash run_eval_gpu.sh 0_8-24_1012.ckpt score"
+echo "=============================================================================================================="
+
+CKPT=$1
+METHOD=$2
+
+if [ $# == 3 ] && [ "$METHOD" == "pred" ]
+then
+    PATH1=$3
+    python eval.py \
+    --device_target="GPU" \
+    --ckpt=$CKPT \
+    --method=$METHOD \
+    --path=$PATH1 > output.eval.log 2>&1 &
+else
+    python eval.py \
+    --device_target="GPU" \
+    --ckpt=$CKPT \
+    --method=$METHOD > output.eval.log 2>&1 &
+fi
diff --git a/model_zoo/research/cv/advanced_east/src/__init__.py b/model_zoo/research/cv/advanced_east/scripts/run_standalone_train_gpu.sh
similarity index 59%
rename from model_zoo/research/cv/advanced_east/src/__init__.py
rename to model_zoo/research/cv/advanced_east/scripts/run_standalone_train_gpu.sh
index 6228b713269..e5e04d7575c 100644
--- a/model_zoo/research/cv/advanced_east/src/__init__.py
+++ b/model_zoo/research/cv/advanced_east/scripts/run_standalone_train_gpu.sh
@@ -1,3 +1,4 @@
+#!/bin/bash
 # Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,3 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_standalone_train_gpu.sh" +echo "for example: bash run_standalone_train_gpu.sh" +echo "==============================================================================================================" + + python train.py \ + --device_target="GPU" > output.train.log 2>&1 & diff --git a/model_zoo/research/cv/advanced_east/src/config.py b/model_zoo/research/cv/advanced_east/src/config.py index 826b13b69ca..2061257ed64 100644 --- a/model_zoo/research/cv/advanced_east/src/config.py +++ b/model_zoo/research/cv/advanced_east/src/config.py @@ -18,10 +18,11 @@ configs from easydict import EasyDict config = EasyDict({ - 'initial_epoch': 2, 'epoch_num': 6, - 'learning_rate': 1e-4, - 'decay': 3e-5, + 'learning_rate_ascend': 1e-4, + 'learning_rate_gpu': 1e-3, + 'decay_ascend': 3e-5, + 'decay_gpu': 5e-4, 'epsilon': 1e-4, 'batch_size': 2, 'ckpt_interval': 2, @@ -36,7 +37,6 @@ config = EasyDict({ 'validation_split_ratio': 0.1, 'total_img': 10000, 'data_dir': './icpr/', - 'val_data_dir': './icpr/images/', 'train_fname': 'train.txt', 'train_fname_var': 'train_', 'val_fname': 'val.txt', @@ -71,5 +71,33 @@ config = EasyDict({ 'gen_origin_img': True, 'draw_gt_quad': False, 'draw_act_quad': False, - 'vgg_npy': '/disk1/ade/vgg16.npy', + 'vgg_npy': './vgg16.npy', + 'vgg_weights': './src/0-150_5004.ckpt', + 'ds_sink_mode': False +}) + +cifar_cfg = EasyDict({ + "num_classes": 10, + "lr": 0.01, + "lr_init": 0.01, + "lr_max": 0.1, + "lr_epochs": '30,60,90,120', + "lr_scheduler": "step", + "warmup_epochs": 5, + "batch_size": 64, + "max_epoch": 70, + "momentum": 0.9, + "weight_decay": 5e-4, + "loss_scale": 1.0, + "label_smooth": 0, + "label_smooth_factor": 0, + "buffer_size": 10, + "image_size": '224,224', + "pad_mode": 'same', + "padding": 0, + "has_bias": False, + "batch_norm": True, + "keep_checkpoint_max": 10, + "initialize_mode": "XavierUniform", + "has_dropout": False }) diff --git a/model_zoo/research/cv/advanced_east/src/label.py b/model_zoo/research/cv/advanced_east/src/label.py index 604315894f1..6da51ebab72 100644 --- a/model_zoo/research/cv/advanced_east/src/label.py +++ b/model_zoo/research/cv/advanced_east/src/label.py @@ -86,14 +86,14 @@ def shrink(xy_list, ratio=cfg.shrink_ratio): new_xy_list = np.copy(temp_new_xy_list) shrink_edge(temp_new_xy_list, new_xy_list, short_edge, r, theta, ratio) shrink_edge(temp_new_xy_list, new_xy_list, short_edge + 2, r, theta, ratio) - return temp_new_xy_list, new_xy_list, long_edge # 缩短后的长边,缩短后的短边,长边下标 + return temp_new_xy_list, new_xy_list, long_edge def shrink_edge(xy_list, new_xy_list, edge, r, theta, ratio=cfg.shrink_ratio): """shrink edge""" if ratio == 0.0: return - start_point = edge # 边的起始点下标(0或1) + start_point = edge end_point = (edge + 1) % 4 long_start_sign_x = np.sign( xy_list[end_point, 0] - xy_list[start_point, 0]) diff --git a/model_zoo/research/cv/advanced_east/src/model.py b/model_zoo/research/cv/advanced_east/src/model.py index 497ca277eec..b6a9c18c3c6 100644 --- a/model_zoo/research/cv/advanced_east/src/model.py +++ b/model_zoo/research/cv/advanced_east/src/model.py @@ -17,89 +17,168 @@ dataset processing. 
""" import mindspore import mindspore.nn as nn -from mindspore import Tensor, ParameterTuple, Parameter -from mindspore.common.initializer import initializer -from mindspore.nn.optim import AdamWeightDecay +from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.ops import functional as F -from mindspore.ops import operations as P - +from mindspore.ops import ResizeNearestNeighbor +from mindspore import Tensor, ParameterTuple, Parameter +from mindspore.common.initializer import initializer +from mindspore.train.serialization import load_checkpoint, load_param_into_net import numpy as np +from src.vgg import Vgg from src.config import config as cfg +vgg_cfg = { + '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg16(num_classes=1000, args=None, phase="train"): + """ + Get Vgg16 neural network with batch normalization. + + Args: + num_classes (int): Class numbers. Default: 1000. + args(namespace): param for net init. + phase(str): train or test mode. + + Returns: + Cell, cell instance of Vgg16 neural network with batch normalization. + + Examples: + >>> vgg16(num_classes=1000, args=args) + """ + + if args is None: + from src.config import cifar_cfg + args = cifar_cfg + net = Vgg(vgg_cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) + return net + class AdvancedEast(nn.Cell): """ - EAST network definition. + East model + Args: + args """ - def __init__(self): + + def __init__(self, args): super(AdvancedEast, self).__init__() - vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item() + self.device_target = args.device_target + if self.device_target == 'GPU': - def get_var(name, idx): - value = vgg_dict[name][idx] - if idx == 0: - value = np.transpose(value, [3, 2, 0, 1]) - var = Tensor(value) - return var + self.vgg16 = vgg16() + param_dict = load_checkpoint(cfg.vgg_weights) + load_param_into_net(self.vgg16, param_dict) - def get_conv_var(name): - filters = get_var(name, 0) - biases = get_var(name, 1) - return filters, biases + self.bn1 = nn.BatchNorm2d(1024, momentum=0.99, eps=1e-3) + self.conv1 = nn.Conv2d(1024, 128, 1, weight_init='XavierUniform', has_bias=True) + self.relu1 = nn.ReLU() - class VGG_Conv(nn.Cell): - """ - VGG16 network definition. 
- """ - def __init__(self, name): - super(VGG_Conv, self).__init__() - filters, conv_biases = get_conv_var(name) - out_channels, in_channels, filter_size, _ = filters.shape - self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1) - self.bias_add = P.BiasAdd() - self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]), - name='weight') - self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias') - self.relu = P.ReLU() - self.gn = nn.GroupNorm(32, out_channels) + self.bn2 = nn.BatchNorm2d(128, momentum=0.99, eps=1e-3) + self.conv2 = nn.Conv2d(128, 128, 3, padding=1, pad_mode='pad', weight_init='XavierUniform') + self.relu2 = nn.ReLU() - def construct(self, x): - output = self.conv2d(x, self.weight) - output = self.bias_add(output, self.bias) - output = self.gn(output) - output = self.relu(output) - return output + self.bn3 = nn.BatchNorm2d(384, momentum=0.99, eps=1e-3) + self.conv3 = nn.Conv2d(384, 64, 1, weight_init='XavierUniform', has_bias=True) + self.relu3 = nn.ReLU() - self.conv1_1 = VGG_Conv('conv1_1') - self.conv1_2 = VGG_Conv('conv1_2') - self.pool1 = nn.MaxPool2d(2, 2) - self.conv2_1 = VGG_Conv('conv2_1') - self.conv2_2 = VGG_Conv('conv2_2') - self.pool2 = nn.MaxPool2d(2, 2) - self.conv3_1 = VGG_Conv('conv3_1') - self.conv3_2 = VGG_Conv('conv3_2') - self.conv3_3 = VGG_Conv('conv3_3') - self.pool3 = nn.MaxPool2d(2, 2) - self.conv4_1 = VGG_Conv('conv4_1') - self.conv4_2 = VGG_Conv('conv4_2') - self.conv4_3 = VGG_Conv('conv4_3') - self.pool4 = nn.MaxPool2d(2, 2) - self.conv5_1 = VGG_Conv('conv5_1') - self.conv5_2 = VGG_Conv('conv5_2') - self.conv5_3 = VGG_Conv('conv5_3') - self.pool5 = nn.MaxPool2d(2, 2) - self.merging1 = self.merging(i=2) - self.merging2 = self.merging(i=3) - self.merging3 = self.merging(i=4) - self.last_bn = nn.GroupNorm(16, 32) - self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, has_bias=True, weight_init='XavierUniform') - self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform') - self.side_v_angle_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform') - self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, has_bias=True, weight_init='XavierUniform') - self.op_concat = P.Concat(axis=1) - self.relu = P.ReLU() + self.bn4 = nn.BatchNorm2d(64, momentum=0.99, eps=1e-3) + self.conv4 = nn.Conv2d(64, 64, 3, padding=1, pad_mode='pad', weight_init='XavierUniform') + self.relu4 = nn.ReLU() + + self.bn5 = nn.BatchNorm2d(192, momentum=0.99, eps=1e-3) + self.conv5 = nn.Conv2d(192, 32, 1, weight_init='XavierUniform', has_bias=True) + self.relu5 = nn.ReLU() + + self.bn6 = nn.BatchNorm2d(32, momentum=0.99, eps=1e-3) + self.conv6 = nn.Conv2d(32, 32, 3, padding=1, pad_mode='pad', weight_init='XavierUniform', has_bias=True) + self.relu6 = nn.ReLU() + + self.bn7 = nn.BatchNorm2d(32, momentum=0.99, eps=1e-3) + self.conv7 = nn.Conv2d(32, 32, 3, padding=1, pad_mode='pad', weight_init='XavierUniform', has_bias=True) + self.relu7 = nn.ReLU() + + self.cat = P.Concat(axis=1) + + self.conv8 = nn.Conv2d(32, 1, 1, weight_init='XavierUniform', has_bias=True) + self.conv9 = nn.Conv2d(32, 2, 1, weight_init='XavierUniform', has_bias=True) + self.conv10 = nn.Conv2d(32, 4, 1, weight_init='XavierUniform', has_bias=True) + else: + vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item() + + def get_var(name, idx): + value = vgg_dict[name][idx] + if idx == 0: + value = np.transpose(value, 
[3, 2, 0, 1]) + var = Tensor(value) + return var + + def get_conv_var(name): + filters = get_var(name, 0) + biases = get_var(name, 1) + return filters, biases + + class VGG_Conv(nn.Cell): + """ + VGG16 network definition. + """ + + def __init__(self, name): + super(VGG_Conv, self).__init__() + filters, conv_biases = get_conv_var(name) + out_channels, in_channels, filter_size, _ = filters.shape + self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1) + self.bias_add = P.BiasAdd() + self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]), + name='weight') + self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias') + self.relu = P.ReLU() + self.gn = nn.GroupNorm(32, out_channels) + + def construct(self, x): + output = self.conv2d(x, self.weight) + output = self.bias_add(output, self.bias) + output = self.gn(output) + output = self.relu(output) + return output + + self.conv1_1 = VGG_Conv('conv1_1') + self.conv1_2 = VGG_Conv('conv1_2') + self.pool1 = nn.MaxPool2d(2, 2) + self.conv2_1 = VGG_Conv('conv2_1') + self.conv2_2 = VGG_Conv('conv2_2') + self.pool2 = nn.MaxPool2d(2, 2) + self.conv3_1 = VGG_Conv('conv3_1') + self.conv3_2 = VGG_Conv('conv3_2') + self.conv3_3 = VGG_Conv('conv3_3') + self.pool3 = nn.MaxPool2d(2, 2) + self.conv4_1 = VGG_Conv('conv4_1') + self.conv4_2 = VGG_Conv('conv4_2') + self.conv4_3 = VGG_Conv('conv4_3') + self.pool4 = nn.MaxPool2d(2, 2) + self.conv5_1 = VGG_Conv('conv5_1') + self.conv5_2 = VGG_Conv('conv5_2') + self.conv5_3 = VGG_Conv('conv5_3') + self.pool5 = nn.MaxPool2d(2, 2) + self.merging1 = self.merging(i=2) + self.merging2 = self.merging(i=3) + self.merging3 = self.merging(i=4) + self.last_bn = nn.GroupNorm(16, 32) + self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, has_bias=True, weight_init='XavierUniform') + self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, has_bias=True, + weight_init='XavierUniform') + self.side_v_angle_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, has_bias=True, + weight_init='XavierUniform') + self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, has_bias=True, + weight_init='XavierUniform') + self.op_concat = P.Concat(axis=1) + self.relu = P.ReLU() def merging(self, i=2): """ @@ -119,141 +198,190 @@ class AdvancedEast(nn.Cell): def construct(self, x): """ - forward func + forward func """ - f4 = self.conv1_1(x) - f4 = self.conv1_2(f4) - f4 = self.pool1(f4) - f4 = self.conv2_1(f4) - f4 = self.conv2_2(f4) - f4 = self.pool2(f4) - f3 = self.conv3_1(f4) - f3 = self.conv3_2(f3) - f3 = self.conv3_3(f3) - f3 = self.pool3(f3) - f2 = self.conv4_1(f3) - f2 = self.conv4_2(f2) - f2 = self.conv4_3(f2) - f2 = self.pool4(f2) - f1 = self.conv5_1(f2) - f1 = self.conv5_2(f1) - f1 = self.conv5_3(f1) - f1 = self.pool5(f1) - h1 = f1 - _, _, h_, w_ = P.Shape()(h1) - H1 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h1) - concat1 = self.op_concat((H1, f2)) - h2 = self.merging1(concat1) - _, _, h_, w_ = P.Shape()(h2) - H2 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h2) - concat2 = self.op_concat((H2, f3)) - h3 = self.merging2(concat2) - _, _, h_, w_ = P.Shape()(h3) - H3 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h3) - concat3 = self.op_concat((H3, f4)) - h4 = self.merging3(concat3) - before_output = self.relu(self.last_bn(self.conv_last(h4))) - inside_score = self.inside_score_conv(before_output) - side_v_angle = self.side_v_angle_conv(before_output) - side_v_coord = self.side_v_coord_conv(before_output) - east_detect = 
self.op_concat((inside_score, side_v_coord, side_v_angle)) + if self.device_target == 'GPU': + l2, l3, l4, l5 = self.vgg16(x) + h = l5 + + _, _, h_, w_ = P.Shape()(h) + g = ResizeNearestNeighbor((h_ * 2, w_ * 2))(h) + c = self.cat((g, l4)) + + c = self.bn1(c) + c = self.conv1(c) + c = self.relu1(c) + + h = self.bn2(c) + h = self.conv2(h) + h = self.relu2(h) + + _, _, h_, w_ = P.Shape()(h) + g = ResizeNearestNeighbor((h_ * 2, w_ * 2))(h) + c = self.cat((g, l3)) + + c = self.bn3(c) + c = self.conv3(c) + c = self.relu3(c) + + h = self.bn4(c) + h = self.conv4(h) # bs 64 w/8 h/8 + h = self.relu4(h) + + _, _, h_, w_ = P.Shape()(h) + g = ResizeNearestNeighbor((h_ * 2, w_ * 2))(h) + c = self.cat((g, l2)) + + c = self.bn5(c) + c = self.conv5(c) + c = self.relu5(c) + + h = self.bn6(c) + h = self.conv6(h) # bs 32 w/4 h/4 + h = self.relu6(h) + + g = self.bn7(h) + g = self.conv7(g) # bs 32 w/4 h/4 + g = self.relu7(g) + # get output + + inside_score = self.conv8(g) + side_v_code = self.conv9(g) + side_v_coord = self.conv10(g) + east_detect = self.cat((inside_score, side_v_code, side_v_coord)) + else: + f4 = self.conv1_1(x) + f4 = self.conv1_2(f4) + f4 = self.pool1(f4) + f4 = self.conv2_1(f4) + f4 = self.conv2_2(f4) + f4 = self.pool2(f4) + f3 = self.conv3_1(f4) + f3 = self.conv3_2(f3) + f3 = self.conv3_3(f3) + f3 = self.pool3(f3) + f2 = self.conv4_1(f3) + f2 = self.conv4_2(f2) + f2 = self.conv4_3(f2) + f2 = self.pool4(f2) + f1 = self.conv5_1(f2) + f1 = self.conv5_2(f1) + f1 = self.conv5_3(f1) + f1 = self.pool5(f1) + h1 = f1 + _, _, h_, w_ = P.Shape()(h1) + H1 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h1) + concat1 = self.op_concat((H1, f2)) + h2 = self.merging1(concat1) + _, _, h_, w_ = P.Shape()(h2) + H2 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h2) + concat2 = self.op_concat((H2, f3)) + h3 = self.merging2(concat2) + _, _, h_, w_ = P.Shape()(h3) + H3 = P.ResizeNearestNeighbor((h_ * 2, w_ * 2))(h3) + concat3 = self.op_concat((H3, f4)) + h4 = self.merging3(concat3) + before_output = self.relu(self.last_bn(self.conv_last(h4))) + inside_score = self.inside_score_conv(before_output) + side_v_angle = self.side_v_angle_conv(before_output) + side_v_coord = self.side_v_coord_conv(before_output) + east_detect = self.op_concat((inside_score, side_v_coord, side_v_angle)) + return east_detect -def dice_loss(gt_score, pred_score): - """dice_loss1""" - inter = P.ReduceSum()(gt_score * pred_score) - union = P.ReduceSum()(gt_score) + P.ReduceSum()(pred_score) + 1e-5 - return 1. - (2 * (inter / union)) - - -def dice_loss2(gt_score, pred_score, mask): - """dice_loss2""" - inter = P.ReduceSum()(gt_score * pred_score * mask) - union = P.ReduceSum()(gt_score * mask) + P.ReduceSum()(pred_score * mask) + 1e-5 - return 1. 
- (2 * (inter / union)) - - -def quad_loss(y_true, y_pred, - lambda_inside_score_loss=0.2, - lambda_side_vertex_code_loss=0.1, - lambda_side_vertex_coord_loss=1.0, - epsilon=1e-4): - """quad loss""" - y_true = P.Transpose()(y_true, (0, 2, 3, 1)) - y_pred = P.Transpose()(y_pred, (0, 2, 3, 1)) - logits = y_pred[:, :, :, :1] - labels = y_true[:, :, :, :1] - predicts = P.Sigmoid()(logits) - inside_score_loss = dice_loss(labels, predicts) - inside_score_loss = inside_score_loss * lambda_inside_score_loss - # loss for side_vertex_code - vertex_logitsp = P.Sigmoid()(y_pred[:, :, :, 1:2]) - vertex_labelsp = y_true[:, :, :, 1:2] - vertex_logitsn = P.Sigmoid()(y_pred[:, :, :, 2:3]) - vertex_labelsn = y_true[:, :, :, 2:3] - labels2 = y_true[:, :, :, 1:2] - side_vertex_code_lossp = dice_loss2(vertex_labelsp, vertex_logitsp, labels) - side_vertex_code_lossn = dice_loss2(vertex_labelsn, vertex_logitsn, labels2) - side_vertex_code_loss = (side_vertex_code_lossp + side_vertex_code_lossn) * lambda_side_vertex_code_loss - # loss for side_vertex_coord delta - g_hat = y_pred[:, :, :, 3:] # N*W*H*8 - g_true = y_true[:, :, :, 3:] - vertex_weights = P.Cast()(P.Equal()(y_true[:, :, :, 1], 1), mindspore.float32) - # vertex_weights=y_true[:, :, :,1] - pixel_wise_smooth_l1norm = smooth_l1_loss(g_hat, g_true, vertex_weights) - side_vertex_coord_loss = P.ReduceSum()(pixel_wise_smooth_l1norm) / ( - P.ReduceSum()(vertex_weights) + epsilon) - side_vertex_coord_loss = side_vertex_coord_loss * lambda_side_vertex_coord_loss - # print(inside_score_loss, side_vertex_code_loss, side_vertex_coord_loss) - return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss - - -def smooth_l1_loss(prediction_tensor, target_tensor, weights): - """smooth l1 loss""" - n_q = P.Reshape()(quad_norm(target_tensor), weights.shape) - diff = P.SmoothL1Loss()(prediction_tensor, target_tensor) - pixel_wise_smooth_l1norm = P.ReduceSum()(diff, -1) / n_q * weights - return pixel_wise_smooth_l1norm - - -def quad_norm(g_true, epsilon=1e-4): - """ quad norm""" - shape = g_true.shape - delta_xy_matrix = P.Reshape()(g_true, (shape[0] * shape[1] * shape[2], 2, 2)) - diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :] - square = diff * diff - distance = P.Sqrt()(P.ReduceSum()(square, -1)) - distance = distance * 4.0 - distance = distance + epsilon - return P.Reshape()(distance, (shape[0], shape[1], shape[2])) - class EastWithLossCell(nn.Cell): - """get loss cell""" + """ + loss + """ - def __init__(self, network, config=None): + def __init__(self, network): super(EastWithLossCell, self).__init__() self.East_network = network self.cat = P.Concat(axis=1) + def dice_loss(self, gt_score, pred_score): + """dice_loss1""" + inter = P.ReduceSum()(gt_score * pred_score) + union = P.ReduceSum()(gt_score) + P.ReduceSum()(pred_score) + 1e-5 + return 1. - (2 * (inter / union)) + + def dice_loss2(self, gt_score, pred_score, mask): + """dice_loss2""" + inter = P.ReduceSum()(gt_score * pred_score * mask) + union = P.ReduceSum()(gt_score * mask) + P.ReduceSum()(pred_score * mask) + 1e-5 + return 1. 
- (2 * (inter / union)) + + def quad_loss(self, y_true, y_pred, + lambda_inside_score_loss=0.2, + lambda_side_vertex_code_loss=0.1, + lambda_side_vertex_coord_loss=1.0, + epsilon=1e-4): + """quad loss""" + y_true = P.Transpose()(y_true, (0, 2, 3, 1)) + y_pred = P.Transpose()(y_pred, (0, 2, 3, 1)) + logits = y_pred[:, :, :, :1] + labels = y_true[:, :, :, :1] + predicts = P.Sigmoid()(logits) + inside_score_loss = self.dice_loss(labels, predicts) + inside_score_loss = inside_score_loss * lambda_inside_score_loss + # loss for side_vertex_code + vertex_logitsp = P.Sigmoid()(y_pred[:, :, :, 1:2]) + vertex_labelsp = y_true[:, :, :, 1:2] + vertex_logitsn = P.Sigmoid()(y_pred[:, :, :, 2:3]) + vertex_labelsn = y_true[:, :, :, 2:3] + labels2 = y_true[:, :, :, 1:2] + side_vertex_code_lossp = self.dice_loss2(vertex_labelsp, vertex_logitsp, labels) + side_vertex_code_lossn = self.dice_loss2(vertex_labelsn, vertex_logitsn, labels2) + side_vertex_code_loss = (side_vertex_code_lossp + side_vertex_code_lossn) * lambda_side_vertex_code_loss + # loss for side_vertex_coord delta + g_hat = y_pred[:, :, :, 3:] # N*W*H*8 + g_true = y_true[:, :, :, 3:] + vertex_weights = P.Cast()(P.Equal()(y_true[:, :, :, 1], 1), mindspore.float32) + + pixel_wise_smooth_l1norm = self.smooth_l1_loss(g_hat, g_true, vertex_weights) + side_vertex_coord_loss = P.ReduceSum()(pixel_wise_smooth_l1norm) / ( + P.ReduceSum()(vertex_weights) + epsilon) + side_vertex_coord_loss = side_vertex_coord_loss * lambda_side_vertex_coord_loss + return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss + + def smooth_l1_loss(self, prediction_tensor, target_tensor, weights): + """smooth l1 loss""" + n_q = P.Reshape()(self.quad_norm(target_tensor), weights.shape) + diff = P.SmoothL1Loss()(prediction_tensor, target_tensor) + pixel_wise_smooth_l1norm = P.ReduceSum()(diff, -1) / n_q * weights + return pixel_wise_smooth_l1norm + + def quad_norm(self, g_true, epsilon=1e-4): + """ quad norm""" + shape = g_true.shape + delta_xy_matrix = P.Reshape()(g_true, (shape[0] * shape[1] * shape[2], 2, 2)) + diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :] + square = diff * diff + distance = P.Sqrt()(P.ReduceSum()(square, -1)) + distance = distance * 4.0 + distance = distance + epsilon + return P.Reshape()(distance, (shape[0], shape[1], shape[2])) + def construct(self, image, label): y_pred = self.East_network(image) - loss = quad_loss(label, y_pred) + loss = self.quad_loss(label, y_pred) return loss class TrainStepWrap(nn.Cell): - """train net warper""" - def __init__(self, network, steps, config=None): + """ + train net + """ + + def __init__(self, network): super(TrainStepWrap, self).__init__() self.network = network self.network.set_train() self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = AdamWeightDecay(self.weights, learning_rate=cfg.learning_rate - , eps=1e-7, weight_decay=cfg.decay) self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.sens = 3.0 + self.sens = 1.0 def construct(self, image, label): weights = self.weights @@ -263,11 +391,11 @@ class TrainStepWrap(nn.Cell): return F.depend(loss, self.optimizer(grads)) -def get_AdvancedEast_net(configure=None, steps=1, mode=True): +def get_AdvancedEast_net(args): """ Get network of wide&deep model. 
""" - AdvancedEast_net = AdvancedEast() - loss_net = EastWithLossCell(AdvancedEast_net, configure) - train_net = TrainStepWrap(loss_net, steps, configure) + AdvancedEast_net = AdvancedEast(args) + loss_net = EastWithLossCell(AdvancedEast_net) + train_net = TrainStepWrap(loss_net) return loss_net, train_net diff --git a/model_zoo/research/cv/advanced_east/src/preprocess.py b/model_zoo/research/cv/advanced_east/src/preprocess.py index f0d23a9ccba..e749566a366 100644 --- a/model_zoo/research/cv/advanced_east/src/preprocess.py +++ b/model_zoo/research/cv/advanced_east/src/preprocess.py @@ -72,7 +72,6 @@ def reorder_vertexes(xy_list): b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0] second_v, fourth_v = 0, 0 for index, i in zip(others, range(len(others))): - # delta = y - (k * x + b) delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid) if delta_y > 0: second_v = index @@ -140,7 +139,7 @@ def preprocess(): if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \ os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)): continue - # d_wight, d_height = resize_image(im) + d_wight, d_height = cfg.max_train_img_size, cfg.max_train_img_size scale_ratio_w = d_wight / im.width scale_ratio_h = d_height / im.height @@ -242,7 +241,7 @@ def preprocess_size(width=256): if os.path.exists(os.path.join(train_image_dir, o_img_fname)) and \ os.path.exists(os.path.join(show_gt_image_dir, o_img_fname)): continue - # d_wight, d_height = resize_image(im) + d_wight, d_height = width, width scale_ratio_w = d_wight / im.width scale_ratio_h = d_height / im.height @@ -303,7 +302,6 @@ def preprocess_size(width=256): train_label_list = os.listdir(train_label_dir) print('found %d train labels.' % len(train_label_list)) - # random.shuffle(train_val_set) val_count = int(cfg.validation_split_ratio * len(train_val_set)) with open(os.path.join(data_dir, cfg.val_fname_var + str(width) + '.txt'), 'a') as f_val: f_val.writelines(train_val_set[:val_count]) diff --git a/model_zoo/research/cv/advanced_east/src/score.py b/model_zoo/research/cv/advanced_east/src/score.py index 1ac0e79c82a..0608a265dd0 100644 --- a/model_zoo/research/cv/advanced_east/src/score.py +++ b/model_zoo/research/cv/advanced_east/src/score.py @@ -21,9 +21,6 @@ from shapely.geometry import Polygon from src.config import config as cfg from src.nms import nms - -# from src.preprocess import resize_image - class Averager(): """Compute average for torch.Tensor, used for loss average.""" @@ -99,12 +96,10 @@ class eval_pre_rec_f1(): return 0, 0, 0 quad_flag = np.zeros(num_quads) gt_flag = np.zeros(num_gts) - # print(num_quads, '-------', num_gts) quad_scores_no_zero = np.array(quad_scores_no_zero) scores_idx = np.argsort(quad_scores_no_zero)[::-1] for i in range(num_quads): idx = scores_idx[i] - # score = quad_scores_no_zero[idx] geo = quad_after_nms_no_zero[idx] for j in range(num_gts): if gt_flag[j] == 0: @@ -122,14 +117,12 @@ class eval_pre_rec_f1(): f1_score = 0 else: f1_score = 2 * pre * rec / (pre + rec) - # print(pre, '---', rec, '---', f1_score) return pre, rec, f1_score def add(self, out, gt_xy_list): """`add`""" self.img_num += len(gt_xy_list) ys = out.asnumpy() # (N, 7, 64, 64) - print(ys.shape) if ys.shape[1] == 7: ys = ys.transpose((0, 2, 3, 1)) # NCHW->NHWC for y, gt_xy in zip(ys, gt_xy_list): @@ -137,7 +130,6 @@ class eval_pre_rec_f1(): cond = np.greater_equal(y[:, :, 0], self.pixel_threshold) activation_pixels = np.where(cond) quad_scores, quad_after_nms = nms(y, activation_pixels) - print(quad_scores) lens = 
len(quad_after_nms) if lens == 0 or (sum(sum(quad_scores)) == 0): if not cfg.quiet: @@ -145,7 +137,6 @@ class eval_pre_rec_f1(): continue else: pre, rec, f1_score = self.eval_one(quad_scores, quad_after_nms, gt_xy) - print(pre, rec, f1_score) self.pre += pre self.rec += rec self.f1_score += f1_score diff --git a/model_zoo/research/cv/advanced_east/src/vgg.py b/model_zoo/research/cv/advanced_east/src/vgg.py new file mode 100644 index 00000000000..6f608aeed71 --- /dev/null +++ b/model_zoo/research/cv/advanced_east/src/vgg.py @@ -0,0 +1,146 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Image classifiation. +""" +import math +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.common import initializer as init +from mindspore.common.initializer import initializer, HeNormal + +from src.config import config +npy = config['vgg_npy'] + +def _make_layer(base, args, batch_norm): + """Make stage network of VGG.""" + layers = [] + in_channels = 3 + layer = [] + for v in base: + if v == 'M': + layer += [nn.MaxPool2d(kernel_size=2, stride=2)] + layers.append(layer) + layer = [] + else: + weight = 'ones' + if args.initialize_mode == "XavierUniform": + weight_shape = (v, in_channels, 3, 3) + weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() + + conv2d = nn.Conv2d(in_channels=in_channels, + out_channels=v, + kernel_size=3, + padding=args.padding, + pad_mode=args.pad_mode, + has_bias=args.has_bias, + weight_init=weight) + if batch_norm: + layer += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] + else: + layer += [conv2d, nn.ReLU()] + in_channels = v + layer1 = nn.SequentialCell(layers[0]) + layer2 = nn.SequentialCell(layers[1]) + layer3 = nn.SequentialCell(layers[2]) + layer4 = nn.SequentialCell(layers[3]) + layer5 = nn.SequentialCell(layers[4]) + + return layer1, layer2, layer3, layer4, layer5 + + +class Vgg(nn.Cell): + """ + VGG network definition. + + Args: + base (list): Configuration for different layers, mainly the channel number of Conv layer. + num_classes (int): Class numbers. Default: 1000. + batch_norm (bool): Whether to do the batchnorm. Default: False. + batch_size (int): Batch size. Default: 1. + + Returns: + Tensor, infer output tensor. 
+ + Examples: + >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + >>> num_classes=1000, batch_norm=False, batch_size=1) + """ + + def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): + super(Vgg, self).__init__() + _ = batch_size + self.layer1, self.layer2, self.layer3, self.layer4, self.layer5 = _make_layer(base, args, batch_norm=batch_norm) + if args.initialize_mode == "KaimingNormal": + #default_recurisive_init(self) + self.custom_init_weight() + + def construct(self, x): + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + l5 = self.layer5(l4) + return l2, l3, l4, l5 + + def custom_init_weight(self): + """ + Init the weight of Conv2d and Dense in the net. + """ + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(init.initializer( + HeNormal(negative_slope=math.sqrt(5), mode='fan_out', nonlinearity='relu'), + cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data(init.initializer( + 'zeros', cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, nn.Dense): + cell.weight.set_data(init.initializer( + init.Normal(0.01), cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data(init.initializer( + 'zeros', cell.bias.shape, cell.bias.dtype)) + + +cfg = { + '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg16(num_classes=1000, args=None, phase="train"): + """ + Get Vgg16 neural network with batch normalization. + + Args: + num_classes (int): Class numbers. Default: 1000. + args(namespace): param for net init. + phase(str): train or test mode. + + Returns: + Cell, cell instance of Vgg16 neural network with batch normalization. + + Examples: + >>> vgg16(num_classes=1000, args=args) + """ + + if args is None: + from src.config import cifar_cfg + args = cifar_cfg + net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) + return net diff --git a/model_zoo/research/cv/advanced_east/train.py b/model_zoo/research/cv/advanced_east/train.py index 0f2c35047c8..58ac67a48a9 100644 --- a/model_zoo/research/cv/advanced_east/train.py +++ b/model_zoo/research/cv/advanced_east/train.py @@ -18,9 +18,11 @@ import argparse import datetime import os +import time +import ast from mindspore import context, Model -from mindspore.communication.management import init, get_group_size +from mindspore.communication.management import init, get_group_size, get_rank from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.context import ParallelMode from mindspore.train.serialization import load_param_into_net, load_checkpoint @@ -34,15 +36,16 @@ from src.config import config as cfg set_seed(1) -def parse_args(cloud_args=None): +def parse_args(): """parameters""" parser = argparse.ArgumentParser('mindspore adveast training') parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], help='device where the code will be implemented. 
(Default: Ascend)') - parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)') + parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend.') # network related - parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load') + parser.add_argument('--pre_trained', default=False, type=ast.literal_eval, + help='model_path, local pretrained model to load') # logging and checkpoint related parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') @@ -55,27 +58,21 @@ def parse_args(cloud_args=None): parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') args_opt = parser.parse_args() - args_opt.initial_epoch = cfg.initial_epoch args_opt.epoch_num = cfg.epoch_num - args_opt.learning_rate = cfg.learning_rate - args_opt.decay = cfg.decay args_opt.batch_size = cfg.batch_size - args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio) - args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio args_opt.ckpt_save_max = cfg.ckpt_save_max args_opt.data_dir = cfg.data_dir args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file - args_opt.train_image_dir_name = cfg.train_image_dir_name - args_opt.train_label_dir_name = cfg.train_label_dir_name - args_opt.results_dir = cfg.results_dir args_opt.last_model_name = cfg.last_model_name args_opt.saved_model_file_path = cfg.saved_model_file_path + args_opt.ds_sink_mode = cfg.ds_sink_mode return args_opt if __name__ == '__main__': args = parse_args() + device_num = int(os.environ.get("DEVICE_NUM", 1)) context.set_context(mode=context.GRAPH_MODE) workers = 32 @@ -86,6 +83,7 @@ if __name__ == '__main__': elif args.device_target == "GPU": context.set_context(device_target=args.device_target) init() + args.rank = get_rank() args.group_size = get_group_size() device_num = args.group_size @@ -114,7 +112,7 @@ if __name__ == '__main__': args.rank_save_ckpt_flag = 1 # get network and init - loss_net, train_net = get_AdvancedEast_net() + loss_net, train_net = get_AdvancedEast_net(args) loss_net.add_flags_recursive(fp32=True) train_net.set_train(False) # pre_trained @@ -136,53 +134,81 @@ if __name__ == '__main__': train_dataset448, batch_num448 = load_adEAST_dataset(mindrecordfile448, batch_size=2, device_num=device_num, rank_id=args.rank, is_training=True, num_parallel_workers=workers) - - train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate - , eps=1e-7, weight_decay=cfg.decay) + start = time.time() + learning_rate = cfg.learning_rate_ascend if args.device_target == 'Ascend' else cfg.learning_rate_gpu + decay = cfg.decay_ascend if args.device_target == 'Ascend' else cfg.decay_gpu + # train model using the images resized to 256 + train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate + , eps=1e-7, weight_decay=decay) model = Model(train_net) time_cb = TimeMonitor(data_size=batch_num256) loss_cb = LossMonitor(per_print_times=batch_num256) - callbacks = [time_cb, loss_cb] - if args.rank_save_ckpt_flag: - ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256, - keep_checkpoint_max=args.ckpt_save_max) - save_ckpt_path = args.saved_model_file_path - ckpt_cb = ModelCheckpoint(config=ckpt_config, - directory=save_ckpt_path, - prefix='Epoch_A{}'.format(args.rank)) - 
@@ -86,6 +83,7 @@
     elif args.device_target == "GPU":
         context.set_context(device_target=args.device_target)
         init()
+        args.rank = get_rank()
         args.group_size = get_group_size()
         device_num = args.group_size
 
@@ -114,7 +112,7 @@
     args.rank_save_ckpt_flag = 1
 
     # get network and init
-    loss_net, train_net = get_AdvancedEast_net()
+    loss_net, train_net = get_AdvancedEast_net(args)
     loss_net.add_flags_recursive(fp32=True)
     train_net.set_train(False)
     # pre_trained
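Each training stage in the next hunk repeats the same callback gating: when running distributed with `is_save_on_master` set, only rank 0 keeps the monitors and the checkpoint callback, while every rank still calls `model.train`. A compact equivalent of that branching, shown only as a sketch with names taken from the diff (it is not code the patch applies):

```python
# One stage's gating, flattened; behaves the same as the if/else below.
callbacks = []
if not (args.is_distributed and args.is_save_on_master) or args.rank == 0:
    callbacks.extend([time_cb, loss_cb, ckpt_cb])
model.train(args.epoch_num, train_dataset=train_dataset256,
            callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
```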
@@ -136,53 +134,81 @@
     train_dataset448, batch_num448 = load_adEAST_dataset(mindrecordfile448, batch_size=2,
                                                          device_num=device_num, rank_id=args.rank, is_training=True,
                                                          num_parallel_workers=workers)
-
-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
+    start = time.time()
+    learning_rate = cfg.learning_rate_ascend if args.device_target == 'Ascend' else cfg.learning_rate_gpu
+    decay = cfg.decay_ascend if args.device_target == 'Ascend' else cfg.decay_gpu
+    # train model using the images resized to 256
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
+                                          eps=1e-7, weight_decay=decay)
     model = Model(train_net)
     time_cb = TimeMonitor(data_size=batch_num256)
     loss_cb = LossMonitor(per_print_times=batch_num256)
-    callbacks = [time_cb, loss_cb]
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = args.saved_model_file_path
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='Epoch_A{}'.format(args.rank))
-        callbacks.append(ckpt_cb)
-    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset256, callbacks=callbacks, dataset_sink_mode=False)
-    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                      , eps=1e-7, weight_decay=cfg.decay)
+    callbacks = []
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num256,
+                                   keep_checkpoint_max=args.ckpt_save_max)
+    save_ckpt_path = args.saved_model_file_path
+    ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                              directory=save_ckpt_path,
+                              prefix='Epoch_A{}'.format(args.rank))
+    if args.is_distributed and args.is_save_on_master:
+        if args.rank == 0:
+            callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset256,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    else:
+        callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset256,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    print(time.time() - start)
+    # train model using the images resized to 384
+    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
+                                      eps=1e-7, weight_decay=decay)
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
+                                          eps=1e-7, weight_decay=decay)
+    model = Model(train_net)
+
     time_cb = TimeMonitor(data_size=batch_num384)
     loss_cb = LossMonitor(per_print_times=batch_num384)
-    callbacks = [time_cb, loss_cb]
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = args.saved_model_file_path
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='Epoch_B{}'.format(args.rank))
-        callbacks.append(ckpt_cb)
-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
+    callbacks = []
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num384,
+                                   keep_checkpoint_max=args.ckpt_save_max)
+    ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                              directory=save_ckpt_path,
+                              prefix='Epoch_B{}'.format(args.rank))
+    if args.is_distributed and args.is_save_on_master:
+        if args.rank == 0:
+            callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset384,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    else:
+        callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset384,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    print(time.time() - start)
+
+    # train model using the images resized to 448
+    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
+                                      eps=1e-7, weight_decay=decay)
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
+                                          eps=1e-7, weight_decay=decay)
     model = Model(train_net)
-    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset384, callbacks=callbacks, dataset_sink_mode=False)
-    model.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                      , eps=1e-7, weight_decay=cfg.decay)
+    time_cb = TimeMonitor(data_size=batch_num448)
     loss_cb = LossMonitor(per_print_times=batch_num448)
-    callbacks = [time_cb, loss_cb]
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = args.saved_model_file_path
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='Epoch_C{}'.format(args.rank))
-        callbacks.append(ckpt_cb)
-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
-    model = Model(train_net)
-    model.train(epoch=cfg.epoch_num, train_dataset=train_dataset448, callbacks=callbacks, dataset_sink_mode=False)
+    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num448,
+                                   keep_checkpoint_max=args.ckpt_save_max)
+    callbacks = []
+    ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                              directory=save_ckpt_path,
+                              prefix='Epoch_C{}'.format(args.rank))
+    if args.is_distributed and args.is_save_on_master:
+        if args.rank == 0:
+            callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset448,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+    else:
+        callbacks.extend([time_cb, loss_cb, ckpt_cb])
+        model.train(args.epoch_num, train_dataset=train_dataset448,
+                    callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)
+
+    print(time.time() - start)
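train.py now spells out the optimizer/callback/train block three times, once per input size (256, 384, 448). If the triplication ever needs taming, a stage helper along these lines would work; this is a hypothetical refactor using the names defined in train.py, not part of the patch (rank gating omitted for brevity):

```python
# Hypothetical helper collapsing the three near-identical training stages.
def run_stage(prefix, dataset, batch_num):
    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
                                          eps=1e-7, weight_decay=decay)
    model = Model(train_net)
    ckpt_cfg = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * batch_num,
                                keep_checkpoint_max=args.ckpt_save_max)
    ckpt_cb = ModelCheckpoint(config=ckpt_cfg, directory=args.saved_model_file_path,
                              prefix='{}{}'.format(prefix, args.rank))
    callbacks = [TimeMonitor(data_size=batch_num),
                 LossMonitor(per_print_times=batch_num), ckpt_cb]
    model.train(args.epoch_num, train_dataset=dataset,
                callbacks=callbacks, dataset_sink_mode=args.ds_sink_mode)

for prefix, ds, n in (('Epoch_A', train_dataset256, batch_num256),
                      ('Epoch_B', train_dataset384, batch_num384),
                      ('Epoch_C', train_dataset448, batch_num448)):
    run_stage(prefix, ds, n)
```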
diff --git a/model_zoo/research/cv/advanced_east/train_mindrecord.py b/model_zoo/research/cv/advanced_east/train_mindrecord.py
index b754f3f8cc2..8460cb55221 100644
--- a/model_zoo/research/cv/advanced_east/train_mindrecord.py
+++ b/model_zoo/research/cv/advanced_east/train_mindrecord.py
@@ -18,6 +18,7 @@
 import argparse
 import datetime
 import os
+import ast
 
 from mindspore import context, Model
 from mindspore.common import set_seed
@@ -40,12 +41,13 @@
     parser = argparse.ArgumentParser('mindspore adveast training')
     parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
-    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
+    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend.')
 
     # network related
-    parser.add_argument('--pre_trained', default=False, type=bool, help='model_path, local pretrained model to load')
-    parser.add_argument('--data_path', default='/disk1/ade/icpr/advanced-east-val_256.mindrecord', type=str)
-    parser.add_argument('--pre_trained_ckpt', default='/disk1/adeast/scripts/1.ckpt', type=str)
+    parser.add_argument('--pre_trained', default=False, type=ast.literal_eval,
+                        help='model_path, local pretrained model to load')
+    parser.add_argument('--data_path', type=str)
+    parser.add_argument('--pre_trained_ckpt', type=str)
 
     # logging and checkpoint related
     parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
@@ -58,19 +60,12 @@
     parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
     args_opt = parser.parse_args()
 
-    args_opt.initial_epoch = cfg.initial_epoch
     args_opt.epoch_num = cfg.epoch_num
-    args_opt.learning_rate = cfg.learning_rate
-    args_opt.decay = cfg.decay
     args_opt.batch_size = cfg.batch_size
-    args_opt.total_train_img = cfg.total_img * (1 - cfg.validation_split_ratio)
-    args_opt.total_valid_img = cfg.total_img * cfg.validation_split_ratio
     args_opt.ckpt_save_max = cfg.ckpt_save_max
     args_opt.data_dir = cfg.data_dir
     args_opt.mindsrecord_train_file = cfg.mindsrecord_train_file
     args_opt.mindsrecord_test_file = cfg.mindsrecord_test_file
-    args_opt.train_image_dir_name = cfg.train_image_dir_name
-    args_opt.train_label_dir_name = cfg.train_label_dir_name
     args_opt.results_dir = cfg.results_dir
     args_opt.last_model_name = cfg.last_model_name
     args_opt.saved_model_file_path = cfg.saved_model_file_path
@@ -123,16 +118,18 @@
     args.rank_save_ckpt_flag = 1
 
     # get network and init
-    loss_net, train_net = get_AdvancedEast_net()
+    loss_net, train_net = get_AdvancedEast_net(args)
     loss_net.add_flags_recursive(fp32=True)
     train_net.set_train(True)
     # pre_trained
     if args.pre_trained:
         load_param_into_net(train_net, load_checkpoint(args.pre_trained_ckpt))
-    # define callbacks
-    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=cfg.learning_rate
-                                          , eps=1e-7, weight_decay=cfg.decay)
+    learning_rate = cfg.learning_rate_ascend if args.device_target == 'Ascend' else cfg.learning_rate_gpu
+    decay = cfg.decay_ascend if args.device_target == 'Ascend' else cfg.decay_gpu
+
+    train_net.optimizer = AdamWeightDecay(train_net.weights, learning_rate=learning_rate,
+                                          eps=1e-7, weight_decay=decay)
     model = Model(train_net)
     time_cb = TimeMonitor(data_size=batch_num)
     loss_cb = LossMonitor(per_print_times=batch_num)