97 lines
3.9 KiB
Python
97 lines
3.9 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
'''
Train the MusicTaggerCNN model.

Hyper-parameters and data/checkpoint paths are read from ``src.model_utils.config``.

Usage:
    python train.py
'''
|
|
|
|
from mindspore import context, nn
|
|
from mindspore.train import Model
|
|
from mindspore.common import set_seed
|
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
|
|
|
from src.model_utils.device_adapter import get_device_id
|
|
from src.model_utils.config import config
|
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
|
from src.dataset import create_dataset
|
|
from src.musictagger import MusicTaggerCNN
|
|
from src.loss import BCELoss
|
|
|
|
|
|
def modelarts_pre_process():
    """Pre-processing hook handed to ``moxing_wrapper`` for ``train``.

    This project needs no extra preparation before training, so the hook is
    intentionally empty and returns ``None``.
    """
|
|
|
|
@moxing_wrapper(pre_process=modelarts_pre_process)
def train(model, dataset_direct, filename, columns_list, num_consumer=4,
          batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
          prefix="model", directory='./', per_print_times=181):
    """Train ``model`` on a dataset built by ``create_dataset``, saving checkpoints.

    Args:
        model: A ``mindspore.train.Model`` wrapping network, loss and optimizer.
        dataset_direct: Data location, passed to ``create_dataset``.
        filename: Dataset file name, passed to ``create_dataset``.
        columns_list: Dataset column names, passed to ``create_dataset``.
        num_consumer: Forwarded as the last argument of ``create_dataset``
            (presumably the number of parallel workers -- confirm in
            ``src/dataset.py``).
        batch: Batch size, passed to ``create_dataset``.
        epoch: Number of training epochs.
        save_checkpoint_steps: Save a checkpoint every this many steps.
        keep_checkpoint_max: Maximum number of checkpoint files to keep.
        prefix: File-name prefix of saved checkpoints.
        directory: Directory where checkpoints are written.
        per_print_times: Print the loss every this many steps. The default of
            181 matches the previously hard-coded value.
    """
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=prefix,
                                 directory=directory,
                                 config=config_ck)
    data_train = create_dataset(dataset_direct, filename, batch, columns_list,
                                num_consumer)

    callbacks = [ckpoint_cb,
                 LossMonitor(per_print_times=per_print_times),
                 TimeMonitor()]
    model.train(epoch, data_train, callbacks=callbacks, dataset_sink_mode=True)
|
|
|
|
|
|
if __name__ == "__main__":
    import os

    set_seed(1)

    context.set_context(device_target='Ascend', mode=context.GRAPH_MODE, device_id=get_device_id())
    context.set_context(enable_auto_mixed_precision=config.mixed_precision)
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)

    if config.pre_trained:
        # Warm-start the network from an existing checkpoint file.
        ckpt_file = os.path.join(config.checkpoint_path, config.model_name)
        param_dict = load_checkpoint(ckpt_file)
        load_param_into_net(network, param_dict)

    net_loss = BCELoss()

    network.set_train(True)
    net_opt = nn.Adam(params=network.trainable_params(),
                      learning_rate=config.lr,
                      loss_scale=config.loss_scale)

    # With drop_overflow_update=False, MindSpore expects the optimizer to
    # carry the same loss_scale as the FixedLossScaleManager, so both read
    # config.loss_scale.
    loss_scale_manager = FixedLossScaleManager(loss_scale=config.loss_scale,
                                               drop_overflow_update=False)
    net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)

    # NOTE(review): a per-device checkpoint directory
    # (config.checkpoint_path + "_{}".format(get_device_id())) was previously
    # left commented out here; consider it for multi-device runs.
    train(model=net_model,
          dataset_direct=config.data_dir,
          filename=config.train_filename,
          columns_list=['feature', 'label'],
          num_consumer=config.num_consumer,
          batch=config.batch_size,
          epoch=config.epoch_size,
          save_checkpoint_steps=config.save_step,
          keep_checkpoint_max=config.keep_checkpoint_max,
          prefix=config.prefix,
          directory=config.checkpoint_path)
    print("train success")
|