mindspore/model_zoo/official/recommend/naml/train.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train NAML."""
import time
from mindspore import nn, load_checkpoint
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
from src.dataset import create_dataset, MINDPreprocess
from src.utils import process_data
from src.callback import Monitor


if __name__ == '__main__':
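    # Parse training options, fix the random seed, and build the NAML network
    # on top of the preprocessed word embedding table.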
    args = get_args("train")
    set_seed(args.seed)
    word_embedding = process_data(args)
    net = NAML(args, word_embedding)
    net_with_loss = NAMLWithLossCell(net)
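    # Optionally warm-start training from an existing checkpoint.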
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, net_with_loss)
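    # Preprocess the MIND data and wrap it into a batched training dataset;
    # rank/group_size allow the same script to be reused for distributed training.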
    mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
                             group_size=args.device_num)
    args.dataset_size = dataset.get_dataset_size()
    args.print_times = min(args.dataset_size, args.print_times)
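    # Choose the optimizer: with --weight_decay, apply a fixed 1e-3 decay to the
    # 'weight' parameters only (biases and other parameters get no decay);
    # otherwise fall back to plain Adam.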
    if args.weight_decay:
        weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
        other_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
        group_params = [{'params': weight_params, 'weight_decay': 1e-3},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': net.trainable_params()}]
        opt = nn.AdamWeightDecay(group_params, args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
    else:
        opt = nn.Adam(net.trainable_params(), args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
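    # With --mixed, train in float16 under a dynamic loss scale, but keep numerically
    # sensitive cells (embedding, softmax, cross-entropy loss) in float32.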
    if args.mixed:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=128.0, scale_factor=2, scale_window=10000)
        net_with_loss.to_float(mstype.float16)
        for _, cell in net_with_loss.cells_and_names():
            if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
                cell.to_float(mstype.float32)
        model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net_with_loss, optimizer=opt)
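    # In dataset sink mode the Monitor callback fires once every sink_size steps, so the
    # epoch count is rescaled to keep the total number of trained steps unchanged.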
    cb = [Monitor(args)]
    epochs = args.epochs
    if args.sink_mode:
        epochs = int(args.epochs * args.dataset_size / args.print_times)
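    # Run training and report the end-to-end wall-clock latency of the whole run.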
    start_time = time.time()
    print("======================= Start Train ==========================", flush=True)
    model.train(epochs, dataset, callbacks=cb, dataset_sink_mode=args.sink_mode, sink_size=args.print_times)
    end_time = time.time()
    print("==============================================================")
    print("processor_name: {}".format(args.platform))
    print("test_name: NAML")
    print(f"model_name: NAML MIND{args.dataset}")
    print("batch_size: {}".format(args.batch_size))
    print("latency: {} s".format(end_time - start_time))