# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train NAML."""
import time

from mindspore import nn, load_checkpoint
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
from src.dataset import create_dataset, MINDPreprocess
from src.utils import process_data
from src.callback import Monitor

if __name__ == '__main__':
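    # Parse training options, fix the random seed, and prepare the word embedding table.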
    args = get_args("train")
    set_seed(args.seed)
    word_embedding = process_data(args)
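    # Build the NAML network on top of the word embeddings and wrap it with its loss cell.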
    net = NAML(args, word_embedding)
    net_with_loss = NAMLWithLossCell(net)
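    # If a checkpoint path is provided, warm-start the network from the pretrained checkpoint.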
    if args.checkpoint_path is not None:
        load_checkpoint(args.pretrain_checkpoint, net_with_loss)
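    # Preprocess the MIND dataset and build the training pipeline; rank and
    # device_num are used to shard the data when running on multiple devices.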
    mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
                             group_size=args.device_num)
    args.dataset_size = dataset.get_dataset_size()
    args.print_times = min(args.dataset_size, args.print_times)
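    # Select the optimizer: when weight decay is enabled, AdamWeightDecay applies
    # decay (1e-3) only to 'weight' parameters; otherwise plain Adam is used.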
    if args.weight_decay:
        weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
        other_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
        group_params = [{'params': weight_params, 'weight_decay': 1e-3},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': net.trainable_params()}]
        opt = nn.AdamWeightDecay(group_params, args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
    else:
        opt = nn.Adam(net.trainable_params(), args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
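    # Mixed precision: run the network in float16 with a dynamic loss scale,
    # keeping numerically sensitive cells (embedding, softmax, loss) in float32.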
    if args.mixed:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=128.0, scale_factor=2, scale_window=10000)
        net_with_loss.to_float(mstype.float16)
        for _, cell in net_with_loss.cells_and_names():
            if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
                cell.to_float(mstype.float32)
        model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net_with_loss, optimizer=opt)
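    # Monitor reports training progress. With dataset sinking enabled, model.train()
    # consumes sink_size (= print_times) steps per "epoch", so rescale the epoch
    # count to keep the total number of training steps unchanged.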
    cb = [Monitor(args)]
    epochs = args.epochs
    if args.sink_mode:
        epochs = int(args.epochs * args.dataset_size / args.print_times)
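    # Train the model and time the whole run.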
    start_time = time.time()
    print("======================= Start Train ==========================", flush=True)
    model.train(epochs, dataset, callbacks=cb, dataset_sink_mode=args.sink_mode, sink_size=args.print_times)
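    # Print a short performance summary for the run.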
    end_time = time.time()
    print("==============================================================")
    print("processor_name: {}".format(args.platform))
    print("test_name: NAML")
    print(f"model_name: NAML MIND{args.dataset}")
    print("batch_size: {}".format(args.batch_size))
    print("latency: {} s".format(end_time - start_time))