!7010 psenet 8p accuracy improvement.

Merge pull request !7010 from linqingke/psenet
This commit is contained in:
mindspore-ci-bot 2020-09-30 09:07:19 +08:00 committed by Gitee
commit afb1a91568
5 changed files with 50 additions and 12 deletions

View File

@ -41,9 +41,9 @@ fi
python ${current_exec_path}/src/generate_hccn_file.py
export DEVICE_NUM=4
export RANK_SIZE=4
export RANK_TABLE_FILE=${current_exec_path}/rank_table_4p.json
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=${current_exec_path}/rank_table_8p.json
for((i=0; i<${DEVICE_NUM}; i++))
do

View File

@ -29,6 +29,12 @@ config = ed({
# neck
'NECK_OUT_CHANNEL': 256,
# lr
"BASE_LR": 2e-3,
"TRAIN_TOTAL_ITER": 58000,
"WARMUP_STEP": 620,
"WARMUP_RATIO": 1/3,
# dataset for train
"TRAIN_ROOT_DIR": 'psenet/ic15/',
"TRAIN_IS_TRANSFORM": True,
@ -37,9 +43,8 @@ config = ed({
"TRAIN_MIN_SCALE": 0.4,
"TRAIN_BUFFER_SIZE": 8,
"TRAIN_BATCH_SIZE": 4,
"TRAIN_REPEAT_NUM": 608*4,
"TRAIN_REPEAT_NUM": 1800,
"TRAIN_DROP_REMAINDER": True,
"TRAIN_TOTAL_ITER": 152000,
"TRAIN_MODEL_SAVE_PATH": './checkpoints/',
# dataset for test

View File

@ -17,7 +17,7 @@
import os
import socket
RANK_TABLE_SAVE_PATH = './rank_table_4p.json'
RANK_TABLE_SAVE_PATH = './rank_table_8p.json'
def main():

View File

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for psenet"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Linearly interpolate the learning rate from *init_lr* toward *base_lr*.

    At ``current_step == 0`` this returns *init_lr*; at
    ``current_step == warmup_steps`` it reaches *base_lr* exactly.
    """
    slope = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    return float(init_lr) + slope * current_step


def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """Half-cosine annealed learning rate for the post-warmup phase.

    The cosine argument is the fraction of *decay_steps* elapsed since the
    end of warmup, so the rate equals *base_lr* at ``current_step ==
    warmup_steps`` and decays along a half cosine afterwards.
    """
    progress = float(current_step - warmup_steps) / float(decay_steps)
    # Same arithmetic order as the reference implementation to keep floats identical.
    return (1 + math.cos(progress * math.pi)) / 2 * base_lr


def dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1/3):
    """Build the per-step learning-rate list: linear warmup, then cosine decay.

    Steps ``[0, warmup_steps)`` ramp linearly from ``base_lr * warmup_ratio``
    up to ``base_lr``; the remaining steps follow the half-cosine schedule.
    """
    warmup_init = base_lr * warmup_ratio
    return [
        linear_warmup_learning_rate(step, warmup_steps, base_lr, warmup_init)
        if step < warmup_steps
        else a_cosine_learning_rate(step, base_lr, warmup_steps, total_steps)
        for step in range(total_steps)
    ]

View File

@ -14,7 +14,6 @@
# ============================================================================
import math
import argparse
import mindspore.nn as nn
from mindspore import context
@ -29,6 +28,7 @@ from src.config import config
from src.ETSNET.etsnet import ETSNet
from src.ETSNET.dice_loss import DiceLoss
from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
from src.lr_schedule import dynamic_lr
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('--run_distribute', default=False, action='store_true',
@ -41,10 +41,6 @@ args = parser.parse_args()
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
def lr_generator(start_lr, lr_scale, total_iters):
lrs = [start_lr * (lr_scale ** math.floor(cur_iter * 1.0 / (total_iters / 3))) for cur_iter in range(total_iters)]
return lrs
def train():
rank_id = 0
if args.run_distribute:
@ -67,7 +63,7 @@ def train():
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER, config.WARMUP_STEP, config.WARMUP_RATIO)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
# warp model