forked from mindspore-Ecosystem/mindspore

!6759 Add music auto tagging network in modelzoo

Merge pull request !6759 from jiangzhenguang/add_music_auto_tagging

commit 2b6a88e09e
@@ -0,0 +1,203 @@ README.md

# Contents

- [Music Auto Tagging Description](#music-auto-tagging-description)
- [Model Architecture](#model-architecture)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Music Auto Tagging Description](#contents)

This repository provides a script and recipe to train the Music Auto Tagging model to achieve state-of-the-art accuracy.

[Paper](https://arxiv.org/abs/1606.00298): Keunwoo Choi, George Fazekas, and Mark Sandler, "Automatic tagging using deep convolutional neural networks," in International Society of Music Information Retrieval Conference. ISMIR, 2016.

# [Model Architecture](#contents)

Music Auto Tagging is a convolutional neural network architecture, also known as FCN-4 because it stacks four convolutional blocks. Its layers consist of convolutional layers, max-pooling layers, activation layers, and a fully connected layer.
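
Concretely, with the constructor arguments used by `train.py` and `eval.py` in this repository, the feature maps evolve as sketched below (channels x mel bands x frames; the fourth pooling stage is realized as a global max over the remaining map rather than `pool4`):

```
input      1 x 96 x 1366   mel spectrogram
block 1  128 x 48 x  341   conv 3x3 -> BN -> ReLU -> maxpool (2, 4)
block 2  384 x 12 x   68   maxpool (4, 5)
block 3  768 x  4 x    8   maxpool (3, 8)
block 4 2048 x  1 x    1   global max (ReduceMax over height and width)
dense          50          one sigmoid score per tag
```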
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for 'reduce precision'.
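
As a minimal, self-contained sketch (not this repository's exact setup; `train.py` below instead toggles `enable_auto_mixed_precision` through `context.set_context`), mixed precision can also be requested through the `amp_level` argument of MindSpore's `Model` wrapper:

```python
# Stand-in network and loss for illustration only; "O2" runs most layers in
# float16 while keeping BatchNorm and the loss computation in float32.
from mindspore import nn
from mindspore.train import Model

net = nn.Dense(10, 2)                       # stand-in network
loss = nn.SoftmaxCrossEntropyWithLogits()   # stand-in loss
opt = nn.Adam(net.trainable_params(), learning_rate=0.0005)
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")
```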
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

### 1. Download and preprocess the dataset

1. Download a classification dataset (for instance, the MagnaTagATune dataset, the Million Song Dataset, etc.).
2. Extract the dataset.
3. The information file should contain the label and path of each clip. Please refer to annotations_final.csv in the MagnaTagATune dataset.
4. The provided pre-processing script uses the MagnaTagATune dataset as an example. Please modify the code according to your own needs.
### 2. Set up parameters (src/config.py)

### 3. Train

After preparing your dataset, first convert the audio clips into a MindRecord dataset by using the following command:

```shell
python pre_process_data.py --device_id 0
```

Then, you can start training the model with the following command:

```shell
SLOG_PRINT_TO_STDOUT=1 python train.py --device_id 0
```

### 4. Test

Then you can test your model:

```shell
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```
├── model_zoo
    ├── README.md                    // descriptions about all the models
    ├── music_auto_tagging
        ├── README.md                // descriptions about Music Auto Tagging
        ├── scripts
        │   ├── run_train.sh         // shell script for training on Ascend
        │   ├── run_eval.sh          // shell script for evaluation on Ascend
        │   ├── run_process_data.sh  // shell script for converting audio clips to MindRecord
        ├── src
        │   ├── dataset.py           // creating dataset
        │   ├── pre_process_data.py  // pre-processing dataset
        │   ├── musictagger.py       // Music Auto Tagging (FCN-4) architecture
        │   ├── config.py            // parameter configuration
        │   ├── loss.py              // loss function
        │   ├── tag.txt              // tag name for each label index
        ├── train.py                 // training script
        ├── eval.py                  // evaluation script
        ├── export.py                // export model in AIR format
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for Music Auto Tagging

```python
'num_classes': 50                     # number of tagging classes
'num_consumer': 4                     # file number for mindrecord
'get_npy': 1                          # mode for converting audio clips to npy files, default 1 in this case
'get_mindrecord': 1                   # mode for converting npy files into mindrecord files, default 1 in this case
'audio_path': "/dev/data/Music_Tagger_Data/fea/"   # path to audio clips
'npy_path': "/dev/data/Music_Tagger_Data/fea/"     # path to numpy files
'info_path': "/dev/data/Music_Tagger_Data/fea/"    # path to info_name, which provides the label of each audio clip
'info_name': 'annotations_final.csv'  # info_name
'device_target': 'Ascend'             # device running the program
'device_id': 0                        # device ID used to train or evaluate the dataset; ignore it when you use run_train.sh for distributed training
'mr_path': '/dev/data/Music_Tagger_Data/fea/'      # path to mindrecord
'mr_name': ['train', 'val']           # mindrecord name

'pre_trained': False                  # whether training is based on a pre-trained model
'lr': 0.0005                          # learning rate
'batch_size': 32                      # training batch size
'epoch_size': 10                      # total training epochs
'loss_scale': 1024.0                  # loss scale
'num_consumer': 4                     # file number for mindrecord
'mixed_precision': False              # whether to use mixed-precision calculation
'train_filename': 'train.mindrecord0' # file name of the training mindrecord data
'val_filename': 'val.mindrecord0'     # file name of the evaluation mindrecord data
'data_dir': '/dev/data/Music_Tagger_Data/fea/'     # directory of mindrecord data
'device_target': 'Ascend'             # device running the program
'device_id': 0,                       # device ID used to train or evaluate the dataset; ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10,            # only keep the last keep_checkpoint_max checkpoints
'save_step': 2000,                    # steps for saving a checkpoint
'checkpoint_path': '/dev/data/Music_Tagger_Data/model/',   # the absolute path to save the checkpoint file
'prefix': 'MusicTagger',              # prefix of the checkpoint
'model_name': 'MusicTagger_3-50_543.ckpt',                 # checkpoint name
```
## [Training Process](#contents)

### Training

- running on Ascend

```
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the script folder by default. The loss values will be printed as follows:

```
# grep "loss is " train.log
epoch: 1 step: 100, loss is 0.23264095
epoch: 1 step: 200, loss is 0.2013525
...
```

The model checkpoint will be saved in the configured directory.
## [Evaluation Process](#contents)

### Evaluation

- running on Ascend
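
As in step 4 of the Quick Start, run the command below; it loads the checkpoint named in `config.py` and prints the mean AUC over all 50 tags:

```shell
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
```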
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | FCN-4                                                       |
| Resource                   | Ascend 910; CPU 2.60GHz, 56 cores; memory, 314 GB           |
| Uploaded Date              | 09/11/2020 (month/day/year)                                 |
| MindSpore Version          | r0.7.0                                                      |
| Training Parameters        | epoch=10, steps=534, batch_size=32, lr=0.005                |
| Optimizer                  | Adam                                                        |
| Loss Function              | Binary cross entropy                                        |
| Outputs                    | probability                                                 |
| AUC                        | 0.909                                                       |
| Speed                      | 1pc: 160 samples/sec                                        |
| Total Time                 | 1pc: 20 mins                                                |
| Checkpoint for Fine Tuning | 198.73M (.ckpt file)                                        |
| Scripts                    | [music_auto_tagging script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/audio/music_auto_tagging) |
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,137 @@ eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############evaluate trained models#################
python eval.py
'''

import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg
from src.dataset import create_dataset
def calculate_auc(labels_list, preds_list):
    """
    The AUC calculation function

    Input:
            labels_list: list of true labels
            preds_list: list of predicted scores
    Outputs
            Float, mean of the per-class AUC
    """
    auc = []
    n_bins = labels_list.shape[0] // 2
    if labels_list.ndim == 1:
        labels_list = labels_list.reshape(-1, 1)
        preds_list = preds_list.reshape(-1, 1)
    for i in range(labels_list.shape[1]):
        labels = labels_list[:, i]
        preds = preds_list[:, i]
        positive_len = labels.sum()
        negative_len = labels.shape[0] - positive_len
        total_case = positive_len * negative_len
        positive_histogram = np.zeros((n_bins))
        negative_histogram = np.zeros((n_bins))
        bin_width = 1.0 / n_bins

        for j, _ in enumerate(labels):
            nth_bin = int(preds[j] // bin_width)
            if labels[j]:
                positive_histogram[nth_bin] = positive_histogram[nth_bin] + 1
            else:
                negative_histogram[nth_bin] = negative_histogram[nth_bin] + 1

        accumulated_negative = 0
        satisfied_pair = 0
        for k in range(n_bins):
            satisfied_pair += (
                positive_histogram[k] * accumulated_negative +
                positive_histogram[k] * negative_histogram[k] * 0.5)
            accumulated_negative += negative_histogram[k]
        auc.append(satisfied_pair / total_case)

    return np.mean(auc)
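
def _demo_auc():
    """Sanity check of calculate_auc on synthetic data (an illustrative
    helper; not part of the original evaluation flow): random scores give
    AUC close to 0.5, and near-perfect scores give AUC of 1.0."""
    demo_labels = np.random.randint(0, 2, size=(1000, 50)).astype(np.float32)
    print(calculate_auc(demo_labels, np.random.rand(1000, 50)))  # ~ 0.5
    print(calculate_auc(demo_labels, 0.99 * demo_labels))        # = 1.0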
def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimates the performance of the trained model

    Input:
            net: the trained neural network
            data_dir: path to the validation dataset
            filename: name of the validation dataset
            num_consumer: split number of the validation dataset
            batch: validation batch size
    Outputs
            Float, AUC
    """
    data_train = create_dataset(data_dir, filename, batch,
                                ['feature', 'label'], num_consumer)
    data_train = data_train.create_tuple_iterator()
    res_pred = []
    res_true = []
    for data, label in data_train:
        x = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(x.asnumpy())
        res_true.append(label.asnumpy())
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    auc = calculate_auc(res_true, res_pred)
    return auc
def validation(net, model_path, data_dir, filename, num_consumer, batch):
    param_dict = load_checkpoint(model_path)
    load_param_into_net(net, param_dict)

    auc = val(net, data_dir, filename, num_consumer, batch)
    return auc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()

    if args.device_id is not None:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)

    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    network.set_train(False)
    auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name,
                         cfg.data_dir, cfg.val_filename, cfg.num_consumer,
                         cfg.batch_size)

    print("=" * 10 + "Validation Performance" + "=" * 10)
    print("AUC: {:.5f}".format(auc_val))
@@ -0,0 +1,40 @@ export.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############export trained models#################
python export.py
'''

import numpy as np
from mindspore.train.serialization import export
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg

if __name__ == "__main__":
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name)
    load_param_into_net(network, param_dict)
    input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
    export(network,
           Tensor(input_data),
           filename="{}/{}.air".format(cfg.checkpoint_path,
                                       cfg.model_name[:-5]),
           file_format="AIR")
@@ -0,0 +1,18 @@ scripts/run_eval.sh

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export SLOG_PRINT_TO_STDOUT=1
python ../eval.py --device_id 0
@@ -0,0 +1,18 @@ scripts/run_process_data.sh

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export SLOG_PRINT_TO_STDOUT=1
python ../src/pre_process_data.py --device_id 0
@@ -0,0 +1,18 @@ scripts/run_train.sh

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export SLOG_PRINT_TO_STDOUT=1
python ../train.py --device_id 0
@@ -0,0 +1,23 @@ src/__init__.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
__init__.py
"""

from . import musictagger
from . import loss
from . import dataset
from . import config
from . import pre_process_data
@@ -0,0 +1,53 @@ src/config.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config settings, used in train.py and eval.py
"""
from easydict import EasyDict as edict

data_cfg = edict({
    'num_classes': 50,
    'num_consumer': 4,
    'get_npy': 1,
    'get_mindrecord': 1,
    'audio_path': "/dev/data/Music_Tagger_Data/fea/",
    'npy_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_name': 'annotations_final.csv',
    'device_target': 'Ascend',
    'device_id': 0,
    'mr_path': '/dev/data/Music_Tagger_Data/fea/',
    'mr_name': ['train', 'val'],
})

music_cfg = edict({
    'pre_trained': False,
    'lr': 0.0005,
    'batch_size': 32,
    'epoch_size': 10,
    'loss_scale': 1024.0,
    'num_consumer': 4,
    'mixed_precision': False,
    'train_filename': 'train.mindrecord0',
    'val_filename': 'val.mindrecord0',
    'data_dir': '/dev/data/Music_Tagger_Data/fea/',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'save_step': 2000,
    'checkpoint_path': '/dev/data/Music_Tagger_Data/model',
    'prefix': 'MusicTagger',
    'model_name': 'MusicTagger_3-50_543.ckpt',
})
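
# Illustrative usage note (not part of the original file): these edicts are
# consumed elsewhere with attribute access, and values can be overridden
# before training, e.g.
#
#   from src.config import music_cfg as cfg
#   cfg.data_dir = "/my/data/"   # hypothetical override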
@@ -0,0 +1,30 @@ src/dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''python dataset.py'''

import os
import mindspore.dataset as ds


def create_dataset(base_path, filename, batch_size, columns_list,
                   num_consumer):
    """Create dataset"""

    path = os.path.join(base_path, filename)
    dtrain = ds.MindDataset(path, columns_list, num_consumer)
    dtrain = dtrain.shuffle(buffer_size=dtrain.get_dataset_size())
    dtrain = dtrain.batch(batch_size, drop_remainder=True)

    return dtrain
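
def _demo_create_dataset():
    """A minimal usage sketch (illustrative helper; the path and filename
    mirror the src/config.py defaults and are assumptions that require the
    MindRecord files to already exist)."""
    data = create_dataset("/dev/data/Music_Tagger_Data/fea/",
                          "train.mindrecord0", 32, ['feature', 'label'], 4)
    print(data.get_dataset_size())  # number of 32-sample batches per epoch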
@@ -0,0 +1,41 @@ src/loss.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define loss
"""
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P


class BCELoss(nn.Cell):
    """
    BCELoss
    """
    def __init__(self, record=None):
        super(BCELoss, self).__init__(record)
        self.sm_scalar = P.ScalarSummary()
        self.cast = P.Cast()
        self.record = record
        self.weight = None
        self.bce = P.BinaryCrossEntropy()

    def construct(self, input_data, target):
        target = self.cast(target, mstype.float32)
        loss = self.bce(input_data, target, self.weight)
        if self.record:
            self.sm_scalar("loss", loss)
        return loss
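
def _demo_bce_loss():
    """Numeric sanity check of BCELoss (an illustrative helper, not part of
    the original file); run in PYNATIVE mode so the cell executes eagerly."""
    import numpy as np
    from mindspore import Tensor, context
    context.set_context(mode=context.PYNATIVE_MODE)
    loss_fn = BCELoss()
    pred = Tensor(np.array([[0.9, 0.1]]), mstype.float32)   # sigmoid outputs
    label = Tensor(np.array([[1.0, 0.0]]), mstype.float32)
    print(loss_fn(pred, label))  # -mean(log(0.9), log(1 - 0.1)) ~= 0.105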
@@ -0,0 +1,83 @@ src/musictagger.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''model'''

from mindspore import nn
from mindspore.ops import operations as P


class MusicTaggerCNN(nn.Cell):
    """
    Music Tagger CNN
    """
    def __init__(self, in_classes, kernel_size, padding, maxpool, has_bias):
        super(MusicTaggerCNN, self).__init__()
        self.in_classes = in_classes
        self.kernel_size = kernel_size
        self.maxpool = maxpool
        self.padding = padding
        self.has_bias = has_bias
        # build model
        self.conv1 = nn.Conv2d(self.in_classes[0], self.in_classes[1],
                               self.kernel_size[0])
        self.conv2 = nn.Conv2d(self.in_classes[1], self.in_classes[2],
                               self.kernel_size[1])
        self.conv3 = nn.Conv2d(self.in_classes[2], self.in_classes[3],
                               self.kernel_size[2])
        self.conv4 = nn.Conv2d(self.in_classes[3], self.in_classes[4],
                               self.kernel_size[3])

        self.bn1 = nn.BatchNorm2d(self.in_classes[1])
        self.bn2 = nn.BatchNorm2d(self.in_classes[2])
        self.bn3 = nn.BatchNorm2d(self.in_classes[3])
        self.bn4 = nn.BatchNorm2d(self.in_classes[4])

        self.pool1 = nn.MaxPool2d(maxpool[0], maxpool[0])
        self.pool2 = nn.MaxPool2d(maxpool[1], maxpool[1])
        self.pool3 = nn.MaxPool2d(maxpool[2], maxpool[2])
        self.pool4 = nn.MaxPool2d(maxpool[3], maxpool[3])
        self.poolreduce = P.ReduceMax(keep_dims=False)
        self.Act = nn.ReLU()
        self.flatten = nn.Flatten()
        self.dense = nn.Dense(2048, 50, activation='sigmoid')
        self.sigmoid = nn.Sigmoid()

    def construct(self, input_data):
        """
        Music Tagger CNN
        """
        x = self.conv1(input_data)
        x = self.bn1(x)
        x = self.Act(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.Act(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.Act(x)
        x = self.pool3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.Act(x)
        x = self.poolreduce(x, (2, 3))
        x = self.flatten(x)
        x = self.dense(x)

        return x
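
def _demo_shapes():
    """Forward-shape check (an illustrative helper, not part of the original
    file): feeds one random 1 x 96 x 1366 mel spectrogram through the
    constructor arguments that train.py and eval.py use; feature maps shrink
    128x48x341 -> 384x12x68 -> 768x4x8 -> 2048 (global max) -> 50 scores."""
    import numpy as np
    from mindspore import Tensor, context
    import mindspore.common.dtype as mstype
    context.set_context(mode=context.PYNATIVE_MODE)
    net = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                         kernel_size=[3, 3, 3, 3, 3],
                         padding=[0] * 5,
                         maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                         has_bias=True)
    data = Tensor(np.random.rand(1, 1, 96, 1366), mstype.float32)
    print(net(data).shape)  # (1, 50): one sigmoid score per tag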
@@ -0,0 +1,226 @@ src/pre_process_data.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''python pre_process_data.py'''

import os
import argparse
import pandas as pd
import numpy as np
import librosa
from mindspore.mindrecord import FileWriter
from mindspore import context
from src.config import data_cfg as cfg
def compute_melgram(audio_path, save_path='', filename='', save_npy=True):
    """
    extract the melgram feature from an audio clip and save it as a numpy array

    Args:
        audio_path (str): path to the audio clip.
        save_path (str): path to save the numpy array.
        filename (str): filename of the audio clip.

    Returns:
        numpy array.
    """
    SR = 12000
    N_FFT = 512
    N_MELS = 96
    HOP_LEN = 256
    DURA = 29.12  # to make it 1366 frames

    src, _ = librosa.load(audio_path, sr=SR)  # whole signal
    n_sample = src.shape[0]
    n_sample_fit = int(DURA * SR)

    if n_sample < n_sample_fit:  # if too short
        src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,))))
    elif n_sample > n_sample_fit:  # if too long
        src = src[(n_sample - n_sample_fit) // 2:(n_sample + n_sample_fit) //
                  2]
    logam = librosa.core.amplitude_to_db
    melgram = librosa.feature.melspectrogram
    ret = logam(
        melgram(y=src, sr=SR, hop_length=HOP_LEN, n_fft=N_FFT, n_mels=N_MELS))
    ret = ret[np.newaxis, np.newaxis, :]
    if save_npy:
        save_path = save_path + filename[:-4] + '.npy'
        np.save(save_path, ret)
    return ret
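
def _demo_melgram():
    """Illustrative usage only (the mp3 path is a placeholder): extract the
    mel-spectrogram feature for a single clip without writing an .npy file."""
    feature = compute_melgram("some_clip.mp3", save_npy=False)
    print(feature.shape)  # (1, 1, 96, 1366): 96 mel bands x 1366 frames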
def get_data(features_data, labels_data):
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i, "feature": feature, "label": label}
        data_list.append(data_json)
    return data_list


def convert(s):
    if s.isdigit():
        return int(s)
    return s
def GetLabel(info_path, info_name):
    """
    separate the dataset into a training set and a validation set

    Args:
        info_path (str): path to the information file.
        info_name (str): name of the information file.
    """
    T = []
    with open(info_path + '/' + info_name, 'rb') as info:
        data = info.readline()
        while data:
            T.append([
                convert(i[1:-1])
                for i in data.strip().decode('utf-8').split("\t")
            ])
            data = info.readline()

    annotation = pd.DataFrame(T[1:], columns=T[0])
    count = []
    for i in annotation.columns[1:-2]:
        count.append([annotation[i].sum() / len(annotation), i])
    count = sorted(count)
    full_label = []
    for i in count[-50:]:
        full_label.append(i[1])
    out = []
    for i in T[1:]:
        index = [k for k, x in enumerate(i) if x == 1]
        label = [T[0][k] for k in index]
        L = [str(0) for k in range(50)]
        L.append(i[-1])
        for j in label:
            if j in full_label:
                ind = full_label.index(j)
                L[ind] = '1'
        out.append(L)
    out = np.array(out)

    Train = []
    Val = []

    for i in out:
        if np.random.rand() > 0.2:
            Train.append(i)
        else:
            Val.append(i)
    np.savetxt("{}/music_tagging_train_tmp.csv".format(info_path),
               np.array(Train),
               fmt='%s',
               delimiter=',')
    np.savetxt("{}/music_tagging_val_tmp.csv".format(info_path),
               np.array(Val),
               fmt='%s',
               delimiter=',')
def generator_md(info_name, file_path, num_classes):
    """
    generate numpy arrays from the features of all audio clips

    Args:
        info_name (str): path to the information file.
        file_path (str): path to the npy files.

    Returns:
        two numpy arrays.
    """
    df = pd.read_csv(info_name, header=None)
    df.columns = [str(i) for i in range(num_classes)] + ["mp3_path"]
    data = []
    label = []
    for i in range(len(df)):
        try:
            data.append(
                np.load(file_path + df.mp3_path.values[i][:-4] +
                        '.npy').reshape(1, 96, 1366))
            label.append(np.array(df[df.columns[:-1]][i:i + 1])[0])
        except FileNotFoundError:
            print("Exception occurred in generator_md.")
    return np.array(data), np.array(label, dtype=np.int32)
def convert_to_mindrecord(info_name, file_path, store_path, mr_name,
                          num_classes):
    """ convert dataset to mindrecord """
    num_shard = 4
    data, label = generator_md(info_name, file_path, num_classes)
    schema_json = {
        "id": {
            "type": "int32"
        },
        "feature": {
            "type": "float32",
            "shape": [1, 96, 1366]
        },
        "label": {
            "type": "int32",
            "shape": [num_classes]
        }
    }

    writer = FileWriter(
        os.path.join(store_path, '{}.mindrecord'.format(mr_name)), num_shard)
    datax = get_data(data, label)
    writer.add_schema(schema_json, "music_tagger_schema")
    writer.add_index(["id"])
    writer.write_raw_data(datax)
    writer.commit()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='get feature')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()

    if cfg.get_npy:
        GetLabel(cfg.info_path, cfg.info_name)
        dirname = os.listdir(cfg.audio_path)
        for d in dirname:
            file_name = os.listdir("{}/{}".format(cfg.audio_path, d))
            if not os.path.isdir("{}/{}".format(cfg.npy_path, d)):
                os.mkdir("{}/{}".format(cfg.npy_path, d))
            for f in file_name:
                compute_melgram("{}/{}/{}".format(cfg.audio_path, d, f),
                                "{}/{}/".format(cfg.npy_path, d), f)

    if cfg.get_mindrecord:
        if args.device_id is not None:
            context.set_context(device_target='Ascend',
                                mode=context.GRAPH_MODE,
                                device_id=args.device_id)
        else:
            context.set_context(device_target='Ascend',
                                mode=context.GRAPH_MODE,
                                device_id=cfg.device_id)
        for cmn in cfg.mr_name:
            if cmn in ['train', 'val']:
                convert_to_mindrecord('music_tagging_{}_tmp.csv'.format(cmn),
                                      cfg.npy_path, cfg.mr_path, cmn,
                                      cfg.num_classes)
@@ -0,0 +1,50 @@ src/tag.txt

choral
female voice
metal
country
weird
no voice
cello
harp
beats
female vocal
male voice
dance
new age
voice
choir
classic
man
solo
sitar
soft
no vocal
pop
male vocal
woman
flute
quiet
loud
harpsichord
no vocals
vocals
singing
male
opera
indian
female
synth
vocal
violin
beat
ambient
piano
fast
rock
electronic
drums
strings
techno
slow
classical
guitar
@@ -0,0 +1,109 @@ train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############train models#################
python train.py
'''
import argparse
from mindspore import context, nn
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.dataset import create_dataset
from src.musictagger import MusicTaggerCNN
from src.loss import BCELoss
from src.config import music_cfg as cfg
def train(model, dataset_direct, filename, columns_list, num_consumer=4,
          batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
          prefix="model", directory='./'):
    """
    train network
    """
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=prefix,
                                 directory=directory,
                                 config=config_ck)
    data_train = create_dataset(dataset_direct, filename, batch, columns_list,
                                num_consumer)

    model.train(epoch,
                data_train,
                callbacks=[
                    ckpoint_cb,
                    LossMonitor(per_print_times=181),
                    TimeMonitor()
                ],
                dataset_sink_mode=True)
if __name__ == "__main__":
    set_seed(1)
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)

    args = parser.parse_args()

    if args.device_id is not None:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)

    context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)

    if cfg.pre_trained:
        param_dict = load_checkpoint(cfg.checkpoint_path + '/' +
                                     cfg.model_name)
        load_param_into_net(network, param_dict)

    net_loss = BCELoss()

    network.set_train(True)
    net_opt = nn.Adam(params=network.trainable_params(),
                      learning_rate=cfg.lr,
                      loss_scale=cfg.loss_scale)

    loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale,
                                               drop_overflow_update=False)
    net_model = Model(network, net_loss, net_opt,
                      loss_scale_manager=loss_scale_manager)

    train(model=net_model,
          dataset_direct=cfg.data_dir,
          filename=cfg.train_filename,
          columns_list=['feature', 'label'],
          num_consumer=cfg.num_consumer,
          batch=cfg.batch_size,
          epoch=cfg.epoch_size,
          save_checkpoint_steps=cfg.save_step,
          keep_checkpoint_max=cfg.keep_checkpoint_max,
          prefix=cfg.prefix,
          directory=cfg.checkpoint_path + "_{}".format(cfg.device_id))
    print("train success")