forked from mindspore-Ecosystem/mindspore
!10134 improve ncf accuracy.
From: @linqingke Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
8e6ac00f36
|
@ -37,7 +37,6 @@ config = ed({
|
|||
|
||||
# dataset for train
|
||||
"TRAIN_ROOT_DIR": "psenet/ic15/",
|
||||
"TRAIN_IS_TRANSFORM": True,
|
||||
"TRAIN_LONG_SIZE": 640,
|
||||
"TRAIN_MIN_SCALE": 0.4,
|
||||
"TRAIN_BATCH_SIZE": 4,
|
||||
|
|
|
@ -160,7 +160,7 @@ def shrink(bboxes, rate, max_shr=20):
|
|||
|
||||
class TrainDataset:
|
||||
def __init__(self):
|
||||
self.is_transform = config.TRAIN_IS_TRANSFORM
|
||||
self.is_transform = True
|
||||
self.img_size = config.TRAIN_LONG_SIZE
|
||||
self.kernel_num = config.KERNEL_NUM
|
||||
self.min_scale = config.TRAIN_MIN_SCALE
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr generator for ncf"""
|
||||
import math
|
||||
|
||||
def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Return the linearly interpolated warm-up learning rate for one step.

    The rate equals ``init_lr`` at step 0 and grows by a constant increment
    per step, so that it would equal ``base_lr`` at step ``warmup_steps``.

    Args:
        current_step: Step index the rate is computed for.
        warmup_steps: Number of steps the ramp spans; must be non-zero.
        base_lr: Rate the ramp is heading towards.
        init_lr: Rate at step 0.

    Returns:
        The learning rate for ``current_step`` as a float.

    Raises:
        ZeroDivisionError: If ``warmup_steps`` is 0.
    """
    # Constant per-step increment of the ramp.
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate
|
||||
|
||||
def _cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """Return the cosine-annealed learning rate for one step.

    The position on the cosine curve is the number of steps elapsed since
    ``warmup_steps`` divided by ``decay_steps``. A position of 0 yields
    ``base_lr`` and a position of 1 yields 0.

    Args:
        current_step: Step index the rate is computed for.
        base_lr: Rate at the start of the cosine curve.
        warmup_steps: Step index at which the cosine curve starts.
        decay_steps: Number of steps spanned by the half cosine period;
            must be non-zero.

    Returns:
        The learning rate for ``current_step`` as a float.

    Raises:
        ZeroDivisionError: If ``decay_steps`` is 0.
    """
    # Fraction of the decay span that has elapsed after warm-up.
    base = float(current_step - warmup_steps) / float(decay_steps)
    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
    return learning_rate
|
||||
|
||||
def dynamic_lr(base_lr, total_steps, warmup_steps):
    """Build the per-step learning-rate schedule: linear warm-up, then cosine.

    Steps below ``warmup_steps`` ramp linearly from 1% of ``base_lr`` towards
    ``base_lr``; every later step follows a cosine curve that starts at
    ``base_lr``.

    Note that the cosine phase is given ``total_steps`` (not
    ``total_steps - warmup_steps``) as its decay span, so the last entry of
    the schedule stays above zero.

    Args:
        base_lr: Peak learning rate reached at the end of warm-up.
        total_steps: Number of entries to generate.
        warmup_steps: Number of leading steps that use the linear ramp.

    Returns:
        A list of ``total_steps`` floats, one learning rate per step.
    """
    # Warm-up starts from 1% of the peak rate; computed once for all steps.
    init_lr = base_lr * 0.01
    return [
        _linear_warmup_learning_rate(step, warmup_steps, base_lr, init_lr)
        if step < warmup_steps
        else _cosine_learning_rate(step, base_lr, warmup_steps, total_steps)
        for step in range(total_steps)
    ]
|
|
@ -26,6 +26,8 @@ from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
from src.lr_schedule import dynamic_lr
|
||||
|
||||
class DenseLayer(nn.Cell):
|
||||
"""
|
||||
Dense layer definition
|
||||
|
@ -223,14 +225,16 @@ class TrainStepWrap(nn.Cell):
|
|||
"""
|
||||
TrainStepWrap definition
|
||||
"""
|
||||
def __init__(self, network, sens=16384.0):
|
||||
def __init__(self, network, total_steps=1, sens=16384.0):
|
||||
super(TrainStepWrap, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_train()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
|
||||
lr = dynamic_lr(0.01, total_steps, 5000)
|
||||
self.optimizer = nn.Adam(self.weights,
|
||||
learning_rate=0.00382059,
|
||||
learning_rate=lr,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
eps=1e-8,
|
||||
|
|
|
@ -22,12 +22,15 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
|
|||
from mindspore import context, Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap
|
||||
|
||||
from config import cfg
|
||||
|
||||
set_seed(1)
|
||||
|
||||
logging.set_verbosity(logging.INFO)
|
||||
|
||||
parser = argparse.ArgumentParser(description='NCF')
|
||||
|
@ -86,7 +89,7 @@ def test_train():
|
|||
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||
mf_dim=16)
|
||||
loss_net = NetWithLossClass(ncf_net)
|
||||
train_net = TrainStepWrap(loss_net)
|
||||
train_net = TrainStepWrap(loss_net, ds_train.get_dataset_size() * (epochs + 1))
|
||||
|
||||
train_net.set_train()
|
||||
|
||||
|
|
Loading…
Reference in New Issue