This commit is contained in:
huchunmei 2021-05-19 11:21:24 +08:00
parent c1b91ff791
commit 7e60bbe181
15 changed files with 490 additions and 124 deletions

View File

@ -97,12 +97,18 @@ SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
│ ├──dataset.py // creating dataset │ ├──dataset.py // creating dataset
│ ├──pre_process_data.py // pre-process dataset │ ├──pre_process_data.py // pre-process dataset
│ ├──musictagger.py // googlenet architecture │ ├──musictagger.py // googlenet architecture
│ ├──config.py // parameter configuration
│ ├──loss.py // loss function │ ├──loss.py // loss function
│ ├──tag.txt // tag for each number │ ├──tag.txt // tag for each number
| └─model_utils
| ├─config.py // Processing configuration parameters
| ├─device_adapter.py // Get cloud ID
| ├─local_adapter.py // Get local ID
| └─moxing_adapter.py // Parameter processing
├── train.py // training script ├── train.py // training script
├── eval.py // evaluation script ├── eval.py // evaluation script
├── export.py // export model in air format ├── export.py // export model in air format
├─default_config.yaml // Training parameter profile
└─train.py // Train net
``` ```
### [Script Parameters](#contents) ### [Script Parameters](#contents)

View File

@ -0,0 +1,58 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
# config of data
num_classes: 50
num_consumer: 4
get_npy: 1
get_mindrecord: 1
audio_path: "/cache/data"
npy_path: "/cache/data"
info_path: "/cache/data"
info_name: 'annotations_final.csv'
device_target: 'Ascend'
device_id: 0
mr_path: "/cache/data"
mr_name: ['train', 'val']
# config of music
pre_trained: False
lr: 0.0005
batch_size: 32
epoch_size: 10
loss_scale: 1024.0
mixed_precision: False
train_filename: 'train.mindrecord0'
val_filename: 'val.mindrecord0'
data_dir: "/cache/data"
keep_checkpoint_max: 10
save_step: 2000
checkpoint_path: "/cache/data/musicTagger"
prefix: 'MusicTagger'
model_name: 'MusicTagger-10_543.ckpt'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
file_name: 'output file name.'
file_format: 'file format'
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ['AIR', 'ONNX', 'MINDIR']

View File

@ -0,0 +1,47 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
# What is the meaning of separating the two dictionaries in the original config file
pre_trained: False
lr: 0.0005
batch_size: 32
epoch_size: 10
loss_scale: 1024.0
num_consumer: 4
mixed_precision: False
train_filename: 'train.mindrecord0'
val_filename: 'val.mindrecord0'
data_dir: "/cache/data"
device_target: 'Ascend'
device_id: 0
keep_checkpoint_max: 10
save_step: 2000
checkpoint_path: "/cache/data/musicTagger"
prefix: 'MusicTagger'
model_name: 'MusicTagger-10_543.ckpt'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
file_name: 'output file name.'
file_format: 'file format'
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ['AIR', 'ONNX', 'MINDIR']

View File

@ -17,15 +17,18 @@
python eval.py python eval.py
''' '''
import argparse
import numpy as np import numpy as np
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
from src.musictagger import MusicTaggerCNN
from src.dataset import create_dataset
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg
from src.dataset import create_dataset
def calculate_auc(labels_list, preds_list): def calculate_auc(labels_list, preds_list):
@ -107,22 +110,15 @@ def validation(net, model_path, data_dir, filename, num_consumer, batch):
return auc return auc
if __name__ == "__main__": def modelarts_process():
parser = argparse.ArgumentParser(description='Evaluate model') pass
parser.add_argument('--device_id',
type=int,
help='device ID',
default=None)
args = parser.parse_args()
if args.device_id is not None: @moxing_wrapper(pre_process=modelarts_process)
context.set_context(device_target=cfg.device_target, def fcn4_eval():
mode=context.GRAPH_MODE, """
device_id=args.device_id) eval network
else: """
context.set_context(device_target=cfg.device_target, context.set_context(device_target=config.device_target, mode=context.GRAPH_MODE, device_id=get_device_id())
mode=context.GRAPH_MODE,
device_id=cfg.device_id)
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
kernel_size=[3, 3, 3, 3, 3], kernel_size=[3, 3, 3, 3, 3],
@ -130,8 +126,12 @@ if __name__ == "__main__":
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
has_bias=True) has_bias=True)
network.set_train(False) network.set_train(False)
auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name, cfg.data_dir, auc_val = validation(network, config.checkpoint_path + "/" + config.model_name, config.data_dir,
cfg.val_filename, cfg.num_consumer, cfg.batch_size) config.val_filename, config.num_consumer, config.batch_size)
print("=" * 10 + "Validation Performance" + "=" * 10) print("=" * 10 + "Validation Performance" + "=" * 10)
print("AUC: {:.5f}".format(auc_val)) print("AUC: {:.5f}".format(auc_val))
if __name__ == "__main__":
fcn4_eval()

View File

@ -21,8 +21,10 @@ import numpy as np
from mindspore.train.serialization import export from mindspore.train.serialization import export
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.model_utils.config import config
from src.musictagger import MusicTaggerCNN from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg
if __name__ == "__main__": if __name__ == "__main__":
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
@ -30,11 +32,11 @@ if __name__ == "__main__":
padding=[0] * 5, padding=[0] * 5,
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
has_bias=True) has_bias=True)
param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name) param_dict = load_checkpoint(config.checkpoint_path + "/" + config.model_name)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32) input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
export(network, export(network,
Tensor(input_data), Tensor(input_data),
filename="{}/{}.air".format(cfg.checkpoint_path, filename="{}/{}.air".format(config.checkpoint_path,
cfg.model_name[:-5]), config.model_name[:-5]),
file_format="AIR") file_format="AIR")

View File

@ -14,5 +14,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
export DATA_PATH=$1
export CKPT_PATH=$2
export DEVICE_ID=$3
export SLOG_PRINT_TO_STDOUT=1 export SLOG_PRINT_TO_STDOUT=1
python ../eval.py --device_id 0 python ../eval.py --data_dir=$DATA_PATH --checkpoint_path=$CKPT_PATH --device_id=$DEVICE_ID > log_eval 2>&1 &

View File

@ -14,5 +14,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
export DATA_PATH=$1
export CKPT_PATH=$2
export DEVICE_ID=$3
export SLOG_PRINT_TO_STDOUT=1 export SLOG_PRINT_TO_STDOUT=1
python ../train.py --device_id 0 python ../train.py --data_dir=$DATA_PATH --checkpoint_path=$CKPT_PATH --device_id=$DEVICE_ID > log 2>&1 &

View File

@ -1,53 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py, eval.py
"""
from easydict import EasyDict as edict
data_cfg = edict({
'num_classes': 50,
'num_consumer': 4,
'get_npy': 1,
'get_mindrecord': 1,
'audio_path': "/dev/data/Music_Tagger_Data/fea/",
'npy_path': "/dev/data/Music_Tagger_Data/fea/",
'info_path': "/dev/data/Music_Tagger_Data/fea/",
'info_name': 'annotations_final.csv',
'device_target': 'Ascend',
'device_id': 0,
'mr_path': '/dev/data/Music_Tagger_Data/fea/',
'mr_name': ['train', 'val'],
})
music_cfg = edict({
'pre_trained': False,
'lr': 0.0005,
'batch_size': 32,
'epoch_size': 10,
'loss_scale': 1024.0,
'num_consumer': 4,
'mixed_precision': False,
'train_filename': 'train.mindrecord0',
'val_filename': 'val.mindrecord0',
'data_dir': '/dev/data/Music_Tagger_Data/fea/',
'device_target': 'Ascend',
'device_id': 0,
'keep_checkpoint_max': 10,
'save_step': 2000,
'checkpoint_path': '/dev/data/Music_Tagger_Data/model',
'prefix': 'MusicTagger',
'model_name': 'MusicTagger_3-50_543.ckpt',
})

View File

@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -21,7 +21,7 @@ import numpy as np
import librosa import librosa
from mindspore.mindrecord import FileWriter from mindspore.mindrecord import FileWriter
from mindspore import context from mindspore import context
from src.config import data_cfg as cfg from model_utils.config import config as cfg
def compute_melgram(audio_path, save_path='', filename='', save_npy=True): def compute_melgram(audio_path, save_path='', filename='', save_npy=True):

View File

@ -16,18 +16,27 @@
##############train models################# ##############train models#################
python train.py python train.py
''' '''
import argparse
from mindspore import context, nn from mindspore import context, nn
from mindspore.train import Model from mindspore.train import Model
from mindspore.common import set_seed from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.model_utils.device_adapter import get_device_id
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import create_dataset from src.dataset import create_dataset
from src.musictagger import MusicTaggerCNN from src.musictagger import MusicTaggerCNN
from src.loss import BCELoss from src.loss import BCELoss
from src.config import music_cfg as cfg
def modelarts_pre_process():
pass
# config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train(model, dataset_direct, filename, columns_list, num_consumer=4, def train(model, dataset_direct, filename, columns_list, num_consumer=4,
batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50, batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
prefix="model", directory='./'): prefix="model", directory='./'):
@ -43,67 +52,46 @@ def train(model, dataset_direct, filename, columns_list, num_consumer=4,
num_consumer) num_consumer)
model.train(epoch, model.train(epoch, data_train, callbacks=[ckpoint_cb, \
data_train, LossMonitor(per_print_times=181), TimeMonitor()], dataset_sink_mode=True)
callbacks=[
ckpoint_cb,
LossMonitor(per_print_times=181),
TimeMonitor()
],
dataset_sink_mode=True)
if __name__ == "__main__": if __name__ == "__main__":
set_seed(1) set_seed(1)
parser = argparse.ArgumentParser(description='Train model')
parser.add_argument('--device_id',
type=int,
help='device ID',
default=None)
args = parser.parse_args() context.set_context(device_target='Ascend', mode=context.GRAPH_MODE, device_id=get_device_id())
context.set_context(enable_auto_mixed_precision=config.mixed_precision)
if args.device_id is not None:
context.set_context(device_target='Ascend',
mode=context.GRAPH_MODE,
device_id=args.device_id)
else:
context.set_context(device_target='Ascend',
mode=context.GRAPH_MODE,
device_id=cfg.device_id)
context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
kernel_size=[3, 3, 3, 3, 3], kernel_size=[3, 3, 3, 3, 3],
padding=[0] * 5, padding=[0] * 5,
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
has_bias=True) has_bias=True)
if cfg.pre_trained: if config.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path + '/' + param_dict = load_checkpoint(config.checkpoint_path + '/' +
cfg.model_name) config.model_name)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
net_loss = BCELoss() net_loss = BCELoss()
network.set_train(True) network.set_train(True)
net_opt = nn.Adam(params=network.trainable_params(), net_opt = nn.Adam(params=network.trainable_params(),
learning_rate=cfg.lr, learning_rate=config.lr,
loss_scale=cfg.loss_scale) loss_scale=config.loss_scale)
loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale, loss_scale_manager = FixedLossScaleManager(loss_scale=config.loss_scale,
drop_overflow_update=False) drop_overflow_update=False)
net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager) net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)
train(model=net_model, train(model=net_model,
dataset_direct=cfg.data_dir, dataset_direct=config.data_dir,
filename=cfg.train_filename, filename=config.train_filename,
columns_list=['feature', 'label'], columns_list=['feature', 'label'],
num_consumer=cfg.num_consumer, num_consumer=config.num_consumer,
batch=cfg.batch_size, batch=config.batch_size,
epoch=cfg.epoch_size, epoch=config.epoch_size,
save_checkpoint_steps=cfg.save_step, save_checkpoint_steps=config.save_step,
keep_checkpoint_max=cfg.keep_checkpoint_max, keep_checkpoint_max=config.keep_checkpoint_max,
prefix=cfg.prefix, prefix=config.prefix,
directory=cfg.checkpoint_path + "_{}".format(cfg.device_id)) directory=config.checkpoint_path + "_{}".format(get_device_id()))
print("train success") print("train success")