forked from mindspore-Ecosystem/mindspore
!7010 psenet 8p accuracy improvement.
Merge pull request !7010 from linqingke/psenet
This commit is contained in:
commit afb1a91568
@@ -41,9 +41,9 @@ fi
 python ${current_exec_path}/src/generate_hccn_file.py

-export DEVICE_NUM=4
-export RANK_SIZE=4
-export RANK_TABLE_FILE=${current_exec_path}/rank_table_4p.json
+export DEVICE_NUM=8
+export RANK_SIZE=8
+export RANK_TABLE_FILE=${current_exec_path}/rank_table_8p.json

 for((i=0; i<${DEVICE_NUM}; i++))
 do
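The move from 4 to 8 devices only pays off if rank_table_8p.json actually describes 8 devices. A minimal sanity check one could run, assuming the usual HCCL rank-table layout (a "server_list" of entries that each carry a "device" array, which generate_hccn_file.py is expected to emit):

    import json

    # Hypothetical check, not part of the patch: the rank table should
    # list exactly RANK_SIZE devices.
    with open('rank_table_8p.json') as f:
        table = json.load(f)
    device_count = sum(len(server['device']) for server in table['server_list'])
    assert device_count == 8, f'rank table lists {device_count} devices, expected 8'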
@@ -29,6 +29,12 @@ config = ed({
     # neck
     'NECK_OUT_CHANNEL': 256,

+    # lr
+    "BASE_LR": 2e-3,
+    "TRAIN_TOTAL_ITER": 58000,
+    "WARMUP_STEP": 620,
+    "WARMUP_RATIO": 1/3,
+
     # dataset for train
     "TRAIN_ROOT_DIR": 'psenet/ic15/',
     "TRAIN_IS_TRANSFORM": True,
@@ -37,9 +43,8 @@ config = ed({
     "TRAIN_MIN_SCALE": 0.4,
     "TRAIN_BUFFER_SIZE": 8,
     "TRAIN_BATCH_SIZE": 4,
-    "TRAIN_REPEAT_NUM": 608*4,
+    "TRAIN_REPEAT_NUM": 1800,
     "TRAIN_DROP_REMAINDER": True,
-    "TRAIN_TOTAL_ITER": 152000,
     "TRAIN_MODEL_SAVE_PATH": './checkpoints/',

     # dataset for test
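Read together, the two config hunks move the iteration budget into the new lr block (152000 down to 58000 total iterations) and give the schedule a short warmup. A rough reading of the new numbers (a sketch, not part of the patch):

    base_lr, total_iter, warmup_step, warmup_ratio = 2e-3, 58000, 620, 1 / 3
    print(warmup_step / total_iter)   # ~0.0107: warmup covers about 1% of training
    print(base_lr * warmup_ratio)     # ~6.67e-4: the lr the warmup ramp starts from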
@@ -17,7 +17,7 @@
 import os
 import socket

-RANK_TABLE_SAVE_PATH = './rank_table_4p.json'
+RANK_TABLE_SAVE_PATH = './rank_table_8p.json'


 def main():
@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lr generator for psenet"""
+import math
+
+def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    learning_rate = float(init_lr) + lr_inc * current_step
+    return learning_rate
+
+def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
+    base = float(current_step - warmup_steps) / float(decay_steps)
+    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
+    return learning_rate
+
+def dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1/3):
+    """dynamic learning rate generator"""
+    lr = []
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
+        else:
+            lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
+
+    return lr
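The new schedule ramps linearly from base_lr * warmup_ratio up to base_lr, then follows a half-cosine decay toward zero. A quick endpoint check with the values added to the config (a sketch, assuming src.lr_schedule is importable from the repo root):

    from src.lr_schedule import dynamic_lr

    lr = dynamic_lr(base_lr=2e-3, total_steps=58000, warmup_steps=620, warmup_ratio=1/3)
    print(lr[0])     # ~6.67e-4: warmup starts at BASE_LR * WARMUP_RATIO
    print(lr[619])   # ~2.0e-3:  warmup ends just under BASE_LR
    print(lr[-1])    # ~5.7e-7:  cosine decay reaches roughly zero at the last step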
@@ -14,7 +14,6 @@
 # ============================================================================


-import math
 import argparse
 import mindspore.nn as nn
 from mindspore import context

@@ -29,6 +28,7 @@ from src.config import config
 from src.ETSNET.etsnet import ETSNet
 from src.ETSNET.dice_loss import DiceLoss
 from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
+from src.lr_schedule import dynamic_lr

 parser = argparse.ArgumentParser(description='Hyperparams')
 parser.add_argument('--run_distribute', default=False, action='store_true',

@@ -41,10 +41,6 @@ args = parser.parse_args()
 set_seed(1)
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

-def lr_generator(start_lr, lr_scale, total_iters):
-    lrs = [start_lr * (lr_scale ** math.floor(cur_iter * 1.0 / (total_iters / 3))) for cur_iter in range(total_iters)]
-    return lrs
-
 def train():
     rank_id = 0
     if args.run_distribute:

@@ -67,7 +63,7 @@ def train():

     criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)

-    lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
+    lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER, config.WARMUP_STEP, config.WARMUP_RATIO)
     opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)

     # warp model
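MindSpore optimizers accept a per-step list of learning rates, so the list produced by dynamic_lr must supply one entry per optimizer step. A consistency check one could add near the lrs line (hypothetical, not in the patch):

    lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER, config.WARMUP_STEP, config.WARMUP_RATIO)
    assert len(lrs) == config.TRAIN_TOTAL_ITER, 'one lr value is needed per training iteration'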