clould
This commit is contained in:
parent
c1b91ff791
commit
7e60bbe181
|
@ -97,12 +97,18 @@ SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
|
||||||
│ ├──dataset.py // creating dataset
|
│ ├──dataset.py // creating dataset
|
||||||
│ ├──pre_process_data.py // pre-process dataset
|
│ ├──pre_process_data.py // pre-process dataset
|
||||||
│ ├──musictagger.py // googlenet architecture
|
│ ├──musictagger.py // googlenet architecture
|
||||||
│ ├──config.py // parameter configuration
|
|
||||||
│ ├──loss.py // loss function
|
│ ├──loss.py // loss function
|
||||||
│ ├──tag.txt // tag for each number
|
│ ├──tag.txt // tag for each number
|
||||||
|
| └─model_utils
|
||||||
|
| ├─config.py // Processing configuration parameters
|
||||||
|
| ├─device_adapter.py // Get cloud ID
|
||||||
|
| ├─local_adapter.py // Get local ID
|
||||||
|
| └─moxing_adapter.py // Parameter processing
|
||||||
├── train.py // training script
|
├── train.py // training script
|
||||||
├── eval.py // evaluation script
|
├── eval.py // evaluation script
|
||||||
├── export.py // export model in air format
|
├── export.py // export model in air format
|
||||||
|
├─default_config.yaml // Training parameter profile
|
||||||
|
└─train.py // Train net
|
||||||
```
|
```
|
||||||
|
|
||||||
### [Script Parameters](#contents)
|
### [Script Parameters](#contents)
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path"
|
||||||
|
device_target: Ascend
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# config of data
|
||||||
|
num_classes: 50
|
||||||
|
num_consumer: 4
|
||||||
|
get_npy: 1
|
||||||
|
get_mindrecord: 1
|
||||||
|
audio_path: "/cache/data"
|
||||||
|
npy_path: "/cache/data"
|
||||||
|
info_path: "/cache/data"
|
||||||
|
info_name: 'annotations_final.csv'
|
||||||
|
device_target: 'Ascend'
|
||||||
|
device_id: 0
|
||||||
|
mr_path: "/cache/data"
|
||||||
|
mr_name: ['train', 'val']
|
||||||
|
|
||||||
|
# config of music
|
||||||
|
pre_trained: False
|
||||||
|
lr: 0.0005
|
||||||
|
batch_size: 32
|
||||||
|
epoch_size: 10
|
||||||
|
loss_scale: 1024.0
|
||||||
|
mixed_precision: False
|
||||||
|
train_filename: 'train.mindrecord0'
|
||||||
|
val_filename: 'val.mindrecord0'
|
||||||
|
data_dir: "/cache/data"
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
save_step: 2000
|
||||||
|
checkpoint_path: "/cache/data/musicTagger"
|
||||||
|
prefix: 'MusicTagger'
|
||||||
|
model_name: 'MusicTagger-10_543.ckpt'
|
||||||
|
|
||||||
|
---
|
||||||
|
# Config description for each option
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
|
||||||
|
device_target: 'Target device type'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
file_name: 'output file name.'
|
||||||
|
file_format: 'file format'
|
||||||
|
|
||||||
|
---
|
||||||
|
device_target: ['Ascend', 'GPU', 'CPU']
|
||||||
|
file_format: ['AIR', 'ONNX', 'MINDIR']
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path"
|
||||||
|
device_target: Ascend
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# What is the meaning of separating the two dictionaries in the original config file?
|
||||||
|
pre_trained: False
|
||||||
|
lr: 0.0005
|
||||||
|
batch_size: 32
|
||||||
|
epoch_size: 10
|
||||||
|
loss_scale: 1024.0
|
||||||
|
num_consumer: 4
|
||||||
|
mixed_precision: False
|
||||||
|
train_filename: 'train.mindrecord0'
|
||||||
|
val_filename: 'val.mindrecord0'
|
||||||
|
data_dir: "/cache/data"
|
||||||
|
device_target: 'Ascend'
|
||||||
|
device_id: 0
|
||||||
|
keep_checkpoint_max: 10
|
||||||
|
save_step: 2000
|
||||||
|
checkpoint_path: "/cache/data/musicTagger"
|
||||||
|
prefix: 'MusicTagger'
|
||||||
|
model_name: 'MusicTagger-10_543.ckpt'
|
||||||
|
|
||||||
|
---
|
||||||
|
# Config description for each option
|
||||||
|
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||||
|
data_url: 'Dataset url for obs'
|
||||||
|
train_url: 'Training output url for obs'
|
||||||
|
data_path: 'Dataset path for local'
|
||||||
|
output_path: 'Training output path for local'
|
||||||
|
|
||||||
|
device_target: 'Target device type'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
file_name: 'output file name.'
|
||||||
|
file_format: 'file format'
|
||||||
|
|
||||||
|
---
|
||||||
|
device_target: ['Ascend', 'GPU', 'CPU']
|
||||||
|
file_format: ['AIR', 'ONNX', 'MINDIR']
|
|
@ -17,15 +17,18 @@
|
||||||
python eval.py
|
python eval.py
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from src.model_utils.config import config
|
||||||
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
|
from src.model_utils.device_adapter import get_device_id
|
||||||
|
from src.musictagger import MusicTaggerCNN
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.musictagger import MusicTaggerCNN
|
|
||||||
from src.config import music_cfg as cfg
|
|
||||||
from src.dataset import create_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_auc(labels_list, preds_list):
|
def calculate_auc(labels_list, preds_list):
|
||||||
|
@ -107,22 +110,15 @@ def validation(net, model_path, data_dir, filename, num_consumer, batch):
|
||||||
return auc
|
return auc
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def modelarts_process():
|
||||||
parser = argparse.ArgumentParser(description='Evaluate model')
|
pass
|
||||||
parser.add_argument('--device_id',
|
|
||||||
type=int,
|
|
||||||
help='device ID',
|
|
||||||
default=None)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.device_id is not None:
|
@moxing_wrapper(pre_process=modelarts_process)
|
||||||
context.set_context(device_target=cfg.device_target,
|
def fcn4_eval():
|
||||||
mode=context.GRAPH_MODE,
|
"""
|
||||||
device_id=args.device_id)
|
eval network
|
||||||
else:
|
"""
|
||||||
context.set_context(device_target=cfg.device_target,
|
context.set_context(device_target=config.device_target, mode=context.GRAPH_MODE, device_id=get_device_id())
|
||||||
mode=context.GRAPH_MODE,
|
|
||||||
device_id=cfg.device_id)
|
|
||||||
|
|
||||||
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
|
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
|
||||||
kernel_size=[3, 3, 3, 3, 3],
|
kernel_size=[3, 3, 3, 3, 3],
|
||||||
|
@ -130,8 +126,12 @@ if __name__ == "__main__":
|
||||||
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
|
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
|
||||||
has_bias=True)
|
has_bias=True)
|
||||||
network.set_train(False)
|
network.set_train(False)
|
||||||
auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name, cfg.data_dir,
|
auc_val = validation(network, config.checkpoint_path + "/" + config.model_name, config.data_dir,
|
||||||
cfg.val_filename, cfg.num_consumer, cfg.batch_size)
|
config.val_filename, config.num_consumer, config.batch_size)
|
||||||
|
|
||||||
print("=" * 10 + "Validation Performance" + "=" * 10)
|
print("=" * 10 + "Validation Performance" + "=" * 10)
|
||||||
print("AUC: {:.5f}".format(auc_val))
|
print("AUC: {:.5f}".format(auc_val))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fcn4_eval()
|
||||||
|
|
|
@ -21,8 +21,10 @@ import numpy as np
|
||||||
from mindspore.train.serialization import export
|
from mindspore.train.serialization import export
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
from src.model_utils.config import config
|
||||||
from src.musictagger import MusicTaggerCNN
|
from src.musictagger import MusicTaggerCNN
|
||||||
from src.config import music_cfg as cfg
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
|
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
|
||||||
|
@ -30,11 +32,11 @@ if __name__ == "__main__":
|
||||||
padding=[0] * 5,
|
padding=[0] * 5,
|
||||||
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
|
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
|
||||||
has_bias=True)
|
has_bias=True)
|
||||||
param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name)
|
param_dict = load_checkpoint(config.checkpoint_path + "/" + config.model_name)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
|
input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
|
||||||
export(network,
|
export(network,
|
||||||
Tensor(input_data),
|
Tensor(input_data),
|
||||||
filename="{}/{}.air".format(cfg.checkpoint_path,
|
filename="{}/{}.air".format(config.checkpoint_path,
|
||||||
cfg.model_name[:-5]),
|
config.model_name[:-5]),
|
||||||
file_format="AIR")
|
file_format="AIR")
|
||||||
|
|
|
@ -14,5 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
export DATA_PATH=$1
|
||||||
|
export CKPT_PATH=$2
|
||||||
|
export DEVICE_ID=$3
|
||||||
export SLOG_PRINT_TO_STDOUT=1
|
export SLOG_PRINT_TO_STDOUT=1
|
||||||
python ../eval.py --device_id 0
|
python ../eval.py --data_dir=$DATA_PATH --checkpoint_path=$CKPT_PATH --device_id=$DEVICE_ID > log_eval 2>&1 &
|
||||||
|
|
|
@ -14,5 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
export DATA_PATH=$1
|
||||||
|
export CKPT_PATH=$2
|
||||||
|
export DEVICE_ID=$3
|
||||||
export SLOG_PRINT_TO_STDOUT=1
|
export SLOG_PRINT_TO_STDOUT=1
|
||||||
python ../train.py --device_id 0
|
python ../train.py --data_dir=$DATA_PATH --checkpoint_path=$CKPT_PATH --device_id=$DEVICE_ID > log 2>&1 &
|
||||||
|
|
|
@ -1,53 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""
|
|
||||||
network config setting, will be used in train.py, eval.py
|
|
||||||
"""
|
|
||||||
from easydict import EasyDict as edict
|
|
||||||
|
|
||||||
data_cfg = edict({
|
|
||||||
'num_classes': 50,
|
|
||||||
'num_consumer': 4,
|
|
||||||
'get_npy': 1,
|
|
||||||
'get_mindrecord': 1,
|
|
||||||
'audio_path': "/dev/data/Music_Tagger_Data/fea/",
|
|
||||||
'npy_path': "/dev/data/Music_Tagger_Data/fea/",
|
|
||||||
'info_path': "/dev/data/Music_Tagger_Data/fea/",
|
|
||||||
'info_name': 'annotations_final.csv',
|
|
||||||
'device_target': 'Ascend',
|
|
||||||
'device_id': 0,
|
|
||||||
'mr_path': '/dev/data/Music_Tagger_Data/fea/',
|
|
||||||
'mr_name': ['train', 'val'],
|
|
||||||
})
|
|
||||||
|
|
||||||
music_cfg = edict({
|
|
||||||
'pre_trained': False,
|
|
||||||
'lr': 0.0005,
|
|
||||||
'batch_size': 32,
|
|
||||||
'epoch_size': 10,
|
|
||||||
'loss_scale': 1024.0,
|
|
||||||
'num_consumer': 4,
|
|
||||||
'mixed_precision': False,
|
|
||||||
'train_filename': 'train.mindrecord0',
|
|
||||||
'val_filename': 'val.mindrecord0',
|
|
||||||
'data_dir': '/dev/data/Music_Tagger_Data/fea/',
|
|
||||||
'device_target': 'Ascend',
|
|
||||||
'device_id': 0,
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
'save_step': 2000,
|
|
||||||
'checkpoint_path': '/dev/data/Music_Tagger_Data/model',
|
|
||||||
'prefix': 'MusicTagger',
|
|
||||||
'model_name': 'MusicTagger_3-50_543.ckpt',
|
|
||||||
})
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Parse arguments"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ast
|
||||||
|
import argparse
|
||||||
|
from pprint import pprint, pformat
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""
|
||||||
|
Configuration namespace. Convert dictionary to members.
|
||||||
|
"""
|
||||||
|
def __init__(self, cfg_dict):
|
||||||
|
for k, v in cfg_dict.items():
|
||||||
|
if isinstance(v, (list, tuple)):
|
||||||
|
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||||
|
else:
|
||||||
|
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return pformat(self.__dict__)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||||
|
"""
|
||||||
|
Parse command line arguments to the configuration according to the default yaml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: Parent parser.
|
||||||
|
cfg: Base configuration.
|
||||||
|
helper: Helper description.
|
||||||
|
cfg_path: Path to the default yaml config.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||||
|
parents=[parser])
|
||||||
|
helper = {} if helper is None else helper
|
||||||
|
choices = {} if choices is None else choices
|
||||||
|
for item in cfg:
|
||||||
|
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||||
|
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||||
|
choice = choices[item] if item in choices else None
|
||||||
|
if isinstance(cfg[item], bool):
|
||||||
|
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
else:
|
||||||
|
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml(yaml_path):
|
||||||
|
"""
|
||||||
|
Parse the yaml config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
yaml_path: Path to the yaml config.
|
||||||
|
"""
|
||||||
|
with open(yaml_path, 'r') as fin:
|
||||||
|
try:
|
||||||
|
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||||
|
cfgs = [x for x in cfgs]
|
||||||
|
if len(cfgs) == 1:
|
||||||
|
cfg_helper = {}
|
||||||
|
cfg = cfgs[0]
|
||||||
|
cfg_choices = {}
|
||||||
|
elif len(cfgs) == 2:
|
||||||
|
cfg, cfg_helper = cfgs
|
||||||
|
cfg_choices = {}
|
||||||
|
elif len(cfgs) == 3:
|
||||||
|
cfg, cfg_helper, cfg_choices = cfgs
|
||||||
|
else:
|
||||||
|
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||||
|
print(cfg_helper)
|
||||||
|
except:
|
||||||
|
raise ValueError("Failed to parse yaml")
|
||||||
|
return cfg, cfg_helper, cfg_choices
|
||||||
|
|
||||||
|
|
||||||
|
def merge(args, cfg):
|
||||||
|
"""
|
||||||
|
Merge the base config from yaml file and command line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Command line arguments.
|
||||||
|
cfg: Base configuration.
|
||||||
|
"""
|
||||||
|
args_var = vars(args)
|
||||||
|
for item in args_var:
|
||||||
|
cfg[item] = args_var[item]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""
|
||||||
|
Get Config according to the yaml file and cli arguments.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
|
||||||
|
help="Config file path")
|
||||||
|
path_args, _ = parser.parse_known_args()
|
||||||
|
default, helper, choices = parse_yaml(path_args.config_path)
|
||||||
|
pprint(default)
|
||||||
|
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||||
|
final_config = merge(args, default)
|
||||||
|
return Config(final_config)
|
||||||
|
|
||||||
|
config = get_config()
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Device adapter for ModelArts"""
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
|
||||||
|
if config.enable_modelarts:
|
||||||
|
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
else:
|
||||||
|
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||||
|
]
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Local adapter"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
return "Local Job"
|
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Moxing adapter for ModelArts"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import functools
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.profiler import Profiler
|
||||||
|
from .config import config
|
||||||
|
|
||||||
|
_global_sync_count = 0
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
job_id = os.getenv('JOB_ID')
|
||||||
|
job_id = job_id if job_id != "" else "default"
|
||||||
|
return job_id
|
||||||
|
|
||||||
|
def sync_data(from_path, to_path):
|
||||||
|
"""
|
||||||
|
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||||
|
Upload data from local directory to remote obs in contrast.
|
||||||
|
"""
|
||||||
|
import moxing as mox
|
||||||
|
import time
|
||||||
|
global _global_sync_count
|
||||||
|
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||||
|
_global_sync_count += 1
|
||||||
|
|
||||||
|
# Each server contains 8 devices as most.
|
||||||
|
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||||
|
print("from path: ", from_path)
|
||||||
|
print("to path: ", to_path)
|
||||||
|
mox.file.copy_parallel(from_path, to_path)
|
||||||
|
print("===finish data synchronization===")
|
||||||
|
try:
|
||||||
|
os.mknod(sync_lock)
|
||||||
|
except IOError:
|
||||||
|
pass
|
||||||
|
print("===save flag===")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if os.path.exists(sync_lock):
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||||
|
|
||||||
|
|
||||||
|
def moxing_wrapper(pre_process=None, post_process=None):
|
||||||
|
"""
|
||||||
|
Moxing wrapper to download dataset and upload outputs.
|
||||||
|
"""
|
||||||
|
def wrapper(run_func):
|
||||||
|
@functools.wraps(run_func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
# Download data from data_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if config.data_url:
|
||||||
|
sync_data(config.data_url, config.data_path)
|
||||||
|
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||||
|
if config.checkpoint_url:
|
||||||
|
sync_data(config.checkpoint_url, config.load_path)
|
||||||
|
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||||
|
if config.train_url:
|
||||||
|
sync_data(config.train_url, config.output_path)
|
||||||
|
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||||
|
|
||||||
|
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||||
|
config.device_num = get_device_num()
|
||||||
|
config.device_id = get_device_id()
|
||||||
|
if not os.path.exists(config.output_path):
|
||||||
|
os.makedirs(config.output_path)
|
||||||
|
|
||||||
|
if pre_process:
|
||||||
|
pre_process()
|
||||||
|
|
||||||
|
if config.enable_profiling:
|
||||||
|
profiler = Profiler()
|
||||||
|
|
||||||
|
run_func(*args, **kwargs)
|
||||||
|
|
||||||
|
if config.enable_profiling:
|
||||||
|
profiler.analyse()
|
||||||
|
|
||||||
|
# Upload data to train_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if post_process:
|
||||||
|
post_process()
|
||||||
|
|
||||||
|
if config.train_url:
|
||||||
|
print("Start to copy output directory")
|
||||||
|
sync_data(config.output_path, config.train_url)
|
||||||
|
return wrapped_func
|
||||||
|
return wrapper
|
|
@ -21,7 +21,7 @@ import numpy as np
|
||||||
import librosa
|
import librosa
|
||||||
from mindspore.mindrecord import FileWriter
|
from mindspore.mindrecord import FileWriter
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from src.config import data_cfg as cfg
|
from model_utils.config import config as cfg
|
||||||
|
|
||||||
|
|
||||||
def compute_melgram(audio_path, save_path='', filename='', save_npy=True):
|
def compute_melgram(audio_path, save_path='', filename='', save_npy=True):
|
||||||
|
|
|
@ -16,18 +16,27 @@
|
||||||
##############train models#################
|
##############train models#################
|
||||||
python train.py
|
python train.py
|
||||||
'''
|
'''
|
||||||
import argparse
|
|
||||||
from mindspore import context, nn
|
from mindspore import context, nn
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
|
||||||
|
from src.model_utils.device_adapter import get_device_id
|
||||||
|
from src.model_utils.config import config
|
||||||
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.musictagger import MusicTaggerCNN
|
from src.musictagger import MusicTaggerCNN
|
||||||
from src.loss import BCELoss
|
from src.loss import BCELoss
|
||||||
from src.config import music_cfg as cfg
|
|
||||||
|
|
||||||
|
|
||||||
|
def modelarts_pre_process():
|
||||||
|
pass
|
||||||
|
# config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
|
||||||
|
|
||||||
|
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||||
def train(model, dataset_direct, filename, columns_list, num_consumer=4,
|
def train(model, dataset_direct, filename, columns_list, num_consumer=4,
|
||||||
batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
|
batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
|
||||||
prefix="model", directory='./'):
|
prefix="model", directory='./'):
|
||||||
|
@ -43,67 +52,46 @@ def train(model, dataset_direct, filename, columns_list, num_consumer=4,
|
||||||
num_consumer)
|
num_consumer)
|
||||||
|
|
||||||
|
|
||||||
model.train(epoch,
|
model.train(epoch, data_train, callbacks=[ckpoint_cb, \
|
||||||
data_train,
|
LossMonitor(per_print_times=181), TimeMonitor()], dataset_sink_mode=True)
|
||||||
callbacks=[
|
|
||||||
ckpoint_cb,
|
|
||||||
LossMonitor(per_print_times=181),
|
|
||||||
TimeMonitor()
|
|
||||||
],
|
|
||||||
dataset_sink_mode=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
parser = argparse.ArgumentParser(description='Train model')
|
|
||||||
parser.add_argument('--device_id',
|
|
||||||
type=int,
|
|
||||||
help='device ID',
|
|
||||||
default=None)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
context.set_context(device_target='Ascend', mode=context.GRAPH_MODE, device_id=get_device_id())
|
||||||
|
context.set_context(enable_auto_mixed_precision=config.mixed_precision)
|
||||||
if args.device_id is not None:
|
|
||||||
context.set_context(device_target='Ascend',
|
|
||||||
mode=context.GRAPH_MODE,
|
|
||||||
device_id=args.device_id)
|
|
||||||
else:
|
|
||||||
context.set_context(device_target='Ascend',
|
|
||||||
mode=context.GRAPH_MODE,
|
|
||||||
device_id=cfg.device_id)
|
|
||||||
|
|
||||||
context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
|
|
||||||
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
|
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
|
||||||
kernel_size=[3, 3, 3, 3, 3],
|
kernel_size=[3, 3, 3, 3, 3],
|
||||||
padding=[0] * 5,
|
padding=[0] * 5,
|
||||||
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
|
maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
|
||||||
has_bias=True)
|
has_bias=True)
|
||||||
|
|
||||||
if cfg.pre_trained:
|
if config.pre_trained:
|
||||||
param_dict = load_checkpoint(cfg.checkpoint_path + '/' +
|
param_dict = load_checkpoint(config.checkpoint_path + '/' +
|
||||||
cfg.model_name)
|
config.model_name)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
|
|
||||||
net_loss = BCELoss()
|
net_loss = BCELoss()
|
||||||
|
|
||||||
network.set_train(True)
|
network.set_train(True)
|
||||||
net_opt = nn.Adam(params=network.trainable_params(),
|
net_opt = nn.Adam(params=network.trainable_params(),
|
||||||
learning_rate=cfg.lr,
|
learning_rate=config.lr,
|
||||||
loss_scale=cfg.loss_scale)
|
loss_scale=config.loss_scale)
|
||||||
|
|
||||||
loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale,
|
loss_scale_manager = FixedLossScaleManager(loss_scale=config.loss_scale,
|
||||||
drop_overflow_update=False)
|
drop_overflow_update=False)
|
||||||
net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)
|
net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)
|
||||||
|
|
||||||
train(model=net_model,
|
train(model=net_model,
|
||||||
dataset_direct=cfg.data_dir,
|
dataset_direct=config.data_dir,
|
||||||
filename=cfg.train_filename,
|
filename=config.train_filename,
|
||||||
columns_list=['feature', 'label'],
|
columns_list=['feature', 'label'],
|
||||||
num_consumer=cfg.num_consumer,
|
num_consumer=config.num_consumer,
|
||||||
batch=cfg.batch_size,
|
batch=config.batch_size,
|
||||||
epoch=cfg.epoch_size,
|
epoch=config.epoch_size,
|
||||||
save_checkpoint_steps=cfg.save_step,
|
save_checkpoint_steps=config.save_step,
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
keep_checkpoint_max=config.keep_checkpoint_max,
|
||||||
prefix=cfg.prefix,
|
prefix=config.prefix,
|
||||||
directory=cfg.checkpoint_path + "_{}".format(cfg.device_id))
|
directory=config.checkpoint_path + "_{}".format(get_device_id()))
|
||||||
print("train success")
|
print("train success")
|
||||||
|
|
Loading…
Reference in New Issue