forked from mindspore-Ecosystem/mindspore
!6759 Add music auto tagging network in modelzoo
Merge pull request !6759 from jiangzhenguang/add_music_auto_tagging
This commit is contained in:
commit 2b6a88e09e

@ -0,0 +1,203 @@
# Contents

- [Music Auto Tagging Description](#music-auto-tagging-description)
- [Model Architecture](#model-architecture)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Music Auto Tagging Description](#contents)

This repository provides a script and recipe to train the Music Auto Tagging model to achieve state-of-the-art accuracy.

[Paper](https://arxiv.org/abs/1606.00298): Keunwoo Choi, George Fazekas, and Mark Sandler, "Automatic tagging using deep convolutional neural networks," in International Society for Music Information Retrieval Conference (ISMIR), 2016.

# [Model Architecture](#contents)

Music Auto Tagging (FCN-4) is a convolutional neural network architecture; the name FCN-4 comes from its four convolutional layers. The network is built from convolutional layers, max-pooling layers, activation layers, and a fully connected layer.
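For reference, `train.py` and `eval.py` instantiate the network as follows (the channel widths, kernel sizes, and pooling windows below are taken directly from those scripts):

```python
from src.musictagger import MusicTaggerCNN

# Four conv blocks whose channel widths grow from 128 to 2048, each
# followed by batch norm, ReLU, and max pooling, ending in a 50-way
# sigmoid dense layer over the pooled features.
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                         kernel_size=[3, 3, 3, 3, 3],
                         padding=[0] * 5,
                         maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                         has_bias=True)
```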
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, while maintaining the network accuracy achieved with single-precision training. Mixed precision training speeds up computation, reduces memory usage, and enables training of larger models or batch sizes on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend automatically handles it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for 'reduce precision'.
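As a minimal sketch, mixed precision is toggled here through the MindSpore context flag that `train.py` sets from `mixed_precision` in `src/config.py`:

```python
from mindspore import context
from src.config import music_cfg as cfg

# train.py enables automatic mixed precision with this flag;
# cfg.mixed_precision defaults to False in src/config.py.
context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
```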
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

### 1. Download and preprocess the dataset

1. Download a music tagging dataset (for instance, the MagnaTagATune dataset or the Million Song Dataset).
2. Extract the dataset.
3. The information file for each clip should contain the label and path. Please refer to `annotations_final.csv` in the MagnaTagATune dataset.
4. The provided pre-processing script uses the MagnaTagATune dataset as an example. Please modify the code according to your own needs.
### 2. Set up parameters (`src/config.py`)

### 3. Train

After preparing your dataset, first convert the audio clips into a MindRecord dataset with the following command:

```shell
python pre_process_data.py --device_id 0
```
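Equivalently, the wrapper script under `scripts/` runs the same conversion (a sketch, assuming it is invoked from the `scripts/` directory, as its relative paths imply):

```shell
cd scripts
bash run_process_data.sh
```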
Then you can start training the model with the following command:

```shell
SLOG_PRINT_TO_STDOUT=1 python train.py --device_id 0
```

### 4. Test

Then you can evaluate the trained model:

```shell
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```
├── model_zoo
    ├── README.md                      // descriptions about all the models
    ├── music_auto_tagging
        ├── README.md                  // descriptions about Music Auto Tagging
        ├── scripts
        │   ├── run_train.sh           // shell script for training on Ascend
        │   ├── run_eval.sh            // shell script for evaluation on Ascend
        │   ├── run_process_data.sh    // shell script for converting audio clips to MindRecord
        ├── src
        │   ├── dataset.py             // creating dataset
        │   ├── pre_process_data.py    // pre-processing the dataset
        │   ├── musictagger.py         // FCN-4 architecture
        │   ├── config.py              // parameter configuration
        │   ├── loss.py                // loss function
        │   ├── tag.txt                // tag name for each label index
        ├── train.py                   // training script
        ├── eval.py                    // evaluation script
        ├── export.py                  // export model in AIR format
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in `src/config.py`.

- config for Music Auto Tagging

```python
# data_cfg: pre-processing parameters
'num_classes': 50,                                 # number of tagging classes
'num_consumer': 4,                                 # file number for mindrecord
'get_npy': 1,                                      # mode for converting audio clips to npy, default 1 in this case
'get_mindrecord': 1,                               # mode for converting npy files into mindrecord files, default 1 in this case
'audio_path': "/dev/data/Music_Tagger_Data/fea/",  # path to audio clips
'npy_path': "/dev/data/Music_Tagger_Data/fea/",    # path to numpy files
'info_path': "/dev/data/Music_Tagger_Data/fea/",   # path to info_name, which provides the label of each audio clip
'info_name': 'annotations_final.csv',              # info_name
'device_target': 'Ascend',                         # device running the program
'device_id': 0,                                    # device ID used to train or evaluate the dataset; ignore it when you use run_train.sh for distributed training
'mr_path': '/dev/data/Music_Tagger_Data/fea/',     # path to mindrecord
'mr_name': ['train', 'val'],                       # mindrecord names

# music_cfg: training and evaluation parameters
'pre_trained': False,                              # whether training is based on a pre-trained model
'lr': 0.0005,                                      # learning rate
'batch_size': 32,                                  # training batch size
'epoch_size': 10,                                  # total training epochs
'loss_scale': 1024.0,                              # loss scale
'num_consumer': 4,                                 # file number for mindrecord
'mixed_precision': False,                          # whether to use mixed-precision calculation
'train_filename': 'train.mindrecord0',             # file name of the train mindrecord data
'val_filename': 'val.mindrecord0',                 # file name of the evaluation mindrecord data
'data_dir': '/dev/data/Music_Tagger_Data/fea/',    # directory of mindrecord data
'device_target': 'Ascend',                         # device running the program
'device_id': 0,                                    # device ID used to train or evaluate the dataset; ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10,                         # only keep the last keep_checkpoint_max checkpoints
'save_step': 2000,                                 # steps between checkpoint saves
'checkpoint_path': '/dev/data/Music_Tagger_Data/model/',  # the absolute path to save the checkpoint file
'prefix': 'MusicTagger',                           # prefix of checkpoint files
'model_name': 'MusicTagger_3-50_543.ckpt',         # checkpoint name
```
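Since both configuration dictionaries are `easydict.EasyDict` instances, individual values can also be overridden in Python before training; a minimal sketch (the path below is hypothetical):

```python
from src.config import music_cfg as cfg

# Attribute assignment works because cfg is an easydict.EasyDict.
cfg.batch_size = 16                          # smaller batch for a quick run
cfg.data_dir = '/path/to/your/mindrecord/'   # hypothetical location
```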
## [Training Process](#contents)

### Training

- running on Ascend

```shell
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the script folder by default. The loss values will be logged as follows:

```
# grep "loss is " train.log
epoch: 1 step: 100, loss is 0.23264095
epoch: 1 step: 200, loss is 0.2013525
...
```

The model checkpoint will be saved in the configured directory.
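To fine-tune from one of these checkpoints, set `pre_trained: True` in `src/config.py`; `train.py` then restores the weights before training starts, along the lines of:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import music_cfg as cfg

# Mirrors the pre_trained branch in train.py: load the checkpoint named
# by cfg.model_name into the freshly built network.
param_dict = load_checkpoint(cfg.checkpoint_path + '/' + cfg.model_name)
load_param_into_net(network, param_dict)
```

Here `network` is assumed to be a `MusicTaggerCNN` built as shown under [Model Architecture](#model-architecture).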
## [Evaluation Process](#contents)

### Evaluation

- running on Ascend
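By analogy with the training section, evaluation can be launched in the background (a sketch; `eval.py` loads the checkpoint named by `model_name` in `src/config.py` and prints the validation AUC):

```shell
python eval.py > eval.log 2>&1 &
```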
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | FCN-4                                                       |
| Resource                   | Ascend 910; CPU 2.60 GHz, 56 cores; memory 314 GB           |
| Uploaded Date              | 09/11/2020 (month/day/year)                                 |
| MindSpore Version          | r0.7.0                                                      |
| Training Parameters        | epoch=10, steps=534, batch_size=32, lr=0.0005               |
| Optimizer                  | Adam                                                        |
| Loss Function              | Binary cross entropy                                        |
| outputs                    | probability                                                 |
| AUC                        | 0.909                                                       |
| Speed                      | 1pc: 160 samples/sec                                        |
| Total time                 | 1pc: 20 mins                                                |
| Checkpoint for Fine tuning | 198.73M (.ckpt file)                                        |
| Scripts                    | [music_auto_tagging script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/audio/music_auto_tagging) |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,137 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############evaluate trained models#################
python eval.py
'''

import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg
from src.dataset import create_dataset


def calculate_auc(labels_list, preds_list):
    """
    The AUC calculation function

    Input:
            labels_list: list of true labels
            preds_list: list of predicted probabilities
    Outputs
            Float, mean AUC over all classes
    """
    auc = []
    n_bins = labels_list.shape[0] // 2
    if labels_list.ndim == 1:
        labels_list = labels_list.reshape(-1, 1)
        preds_list = preds_list.reshape(-1, 1)
    for i in range(labels_list.shape[1]):
        labels = labels_list[:, i]
        preds = preds_list[:, i]
        positive_len = labels.sum()
        negative_len = labels.shape[0] - positive_len
        total_case = positive_len * negative_len
        positive_histogram = np.zeros((n_bins))
        negative_histogram = np.zeros((n_bins))
        bin_width = 1.0 / n_bins

        # Histogram-based AUC: bucket the predictions, then count the
        # positive/negative pairs in which the positive scores higher
        # (ties within a bin count as half a pair).
        for j, _ in enumerate(labels):
            # clamp so a prediction of exactly 1.0 stays in the last bin
            nth_bin = min(int(preds[j] // bin_width), n_bins - 1)
            if labels[j]:
                positive_histogram[nth_bin] = positive_histogram[nth_bin] + 1
            else:
                negative_histogram[nth_bin] = negative_histogram[nth_bin] + 1

        accumulated_negative = 0
        satisfied_pair = 0
        for k in range(n_bins):
            satisfied_pair += (
                positive_histogram[k] * accumulated_negative +
                positive_histogram[k] * negative_histogram[k] * 0.5)
            accumulated_negative += negative_histogram[k]
        auc.append(satisfied_pair / total_case)

    return np.mean(auc)


def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimate the performance of trained model

    Input:
            net: the trained neural network
            data_dir: path to the validation dataset
            filename: name of the validation dataset
            num_consumer: split number of validation dataset
            batch: validation batch size
    Outputs
            Float, AUC
    """
    data_train = create_dataset(data_dir, filename, batch, ['feature', 'label'],
                                num_consumer)
    data_train = data_train.create_tuple_iterator()
    res_pred = []
    res_true = []
    for data, label in data_train:
        x = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(x.asnumpy())
        res_true.append(label.asnumpy())
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    auc = calculate_auc(res_true, res_pred)
    return auc


def validation(net, model_path, data_dir, filename, num_consumer, batch):
    param_dict = load_checkpoint(model_path)
    load_param_into_net(net, param_dict)

    auc = val(net, data_dir, filename, num_consumer, batch)
    return auc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()

    if args.device_id is not None:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)

    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    network.set_train(False)
    auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name, cfg.data_dir,
                         cfg.val_filename, cfg.num_consumer, cfg.batch_size)

    print("=" * 10 + "Validation Performance" + "=" * 10)
    print("AUC: {:.5f}".format(auc_val))

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############export trained model#################
python export.py
'''

import numpy as np
from mindspore.train.serialization import export
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg

if __name__ == "__main__":
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name)
    load_param_into_net(network, param_dict)
    # dummy input matching the expected melgram shape (N, C, 96, 1366)
    input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
    export(network,
           Tensor(input_data),
           filename="{}/{}.air".format(cfg.checkpoint_path,
                                       cfg.model_name[:-5]),
           file_format="AIR")

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export SLOG_PRINT_TO_STDOUT=1
python ../eval.py --device_id 0

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export SLOG_PRINT_TO_STDOUT=1
python ../src/pre_process_data.py --device_id 0

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export SLOG_PRINT_TO_STDOUT=1
python ../train.py --device_id 0

@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
__init__.py
"""

from . import musictagger
from . import loss
from . import dataset
from . import config
from . import pre_process_data

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py, eval.py
"""
from easydict import EasyDict as edict

data_cfg = edict({
    'num_classes': 50,
    'num_consumer': 4,
    'get_npy': 1,
    'get_mindrecord': 1,
    'audio_path': "/dev/data/Music_Tagger_Data/fea/",
    'npy_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_name': 'annotations_final.csv',
    'device_target': 'Ascend',
    'device_id': 0,
    'mr_path': '/dev/data/Music_Tagger_Data/fea/',
    'mr_name': ['train', 'val'],
})

music_cfg = edict({
    'pre_trained': False,
    'lr': 0.0005,
    'batch_size': 32,
    'epoch_size': 10,
    'loss_scale': 1024.0,
    'num_consumer': 4,
    'mixed_precision': False,
    'train_filename': 'train.mindrecord0',
    'val_filename': 'val.mindrecord0',
    'data_dir': '/dev/data/Music_Tagger_Data/fea/',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'save_step': 2000,
    'checkpoint_path': '/dev/data/Music_Tagger_Data/model',
    'prefix': 'MusicTagger',
    'model_name': 'MusicTagger_3-50_543.ckpt',
})

@ -0,0 +1,30 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''python dataset.py'''

import os
import mindspore.dataset as ds


def create_dataset(base_path, filename, batch_size, columns_list,
                   num_consumer):
    """Create dataset"""

    path = os.path.join(base_path, filename)
    dtrain = ds.MindDataset(path, columns_list, num_consumer)
    # global shuffle over the full dataset, then fixed-size batches
    dtrain = dtrain.shuffle(buffer_size=dtrain.get_dataset_size())
    dtrain = dtrain.batch(batch_size, drop_remainder=True)

    return dtrain

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define loss
"""
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P


class BCELoss(nn.Cell):
    """
    BCELoss
    """
    def __init__(self, record=None):
        super(BCELoss, self).__init__(record)
        self.sm_scalar = P.ScalarSummary()
        self.cast = P.Cast()
        self.record = record
        self.weight = None
        self.bce = P.BinaryCrossEntropy()

    def construct(self, input_data, target):
        target = self.cast(target, mstype.float32)
        loss = self.bce(input_data, target, self.weight)
        if self.record:
            # log the loss to MindSpore summary when recording is enabled
            self.sm_scalar("loss", loss)
        return loss

@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''model'''

from mindspore import nn
from mindspore.ops import operations as P


class MusicTaggerCNN(nn.Cell):
    """
    Music Tagger CNN
    """
    def __init__(self, in_classes, kernel_size, padding, maxpool, has_bias):
        super(MusicTaggerCNN, self).__init__()
        self.in_classes = in_classes
        self.kernel_size = kernel_size
        self.maxpool = maxpool
        self.padding = padding
        self.has_bias = has_bias
        # build model
        self.conv1 = nn.Conv2d(self.in_classes[0], self.in_classes[1],
                               self.kernel_size[0])
        self.conv2 = nn.Conv2d(self.in_classes[1], self.in_classes[2],
                               self.kernel_size[1])
        self.conv3 = nn.Conv2d(self.in_classes[2], self.in_classes[3],
                               self.kernel_size[2])
        self.conv4 = nn.Conv2d(self.in_classes[3], self.in_classes[4],
                               self.kernel_size[3])

        self.bn1 = nn.BatchNorm2d(self.in_classes[1])
        self.bn2 = nn.BatchNorm2d(self.in_classes[2])
        self.bn3 = nn.BatchNorm2d(self.in_classes[3])
        self.bn4 = nn.BatchNorm2d(self.in_classes[4])

        self.pool1 = nn.MaxPool2d(maxpool[0], maxpool[0])
        self.pool2 = nn.MaxPool2d(maxpool[1], maxpool[1])
        self.pool3 = nn.MaxPool2d(maxpool[2], maxpool[2])
        self.pool4 = nn.MaxPool2d(maxpool[3], maxpool[3])
        self.poolreduce = P.ReduceMax(keep_dims=False)
        self.Act = nn.ReLU()
        self.flatten = nn.Flatten()
        # note: the dense layer already applies the sigmoid activation
        self.dense = nn.Dense(2048, 50, activation='sigmoid')
        self.sigmoid = nn.Sigmoid()

    def construct(self, input_data):
        """
        Forward pass: four conv-BN-ReLU-pool blocks, global max pooling,
        then a 50-way sigmoid dense layer.
        """
        x = self.conv1(input_data)
        x = self.bn1(x)
        x = self.Act(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.Act(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.Act(x)
        x = self.pool3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.Act(x)
        # global max pooling over the remaining time-frequency dims
        x = self.poolreduce(x, (2, 3))
        x = self.flatten(x)
        x = self.dense(x)

        return x

@ -0,0 +1,226 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''python pre_process_data.py'''

import os
import argparse
import pandas as pd
import numpy as np
import librosa
from mindspore.mindrecord import FileWriter
from mindspore import context
from src.config import data_cfg as cfg


def compute_melgram(audio_path, save_path='', filename='', save_npy=True):
    """
    extract melgram feature from the audio and save as numpy array

    Args:
        audio_path (str): path to the audio clip.
        save_path (str): path to save the numpy array.
        filename (str): filename of the audio clip.
        save_npy (bool): whether to save the feature as a .npy file.

    Returns:
        numpy array.
    """
    SR = 12000
    N_FFT = 512
    N_MELS = 96
    HOP_LEN = 256
    DURA = 29.12  # to make it 1366 frames

    src, _ = librosa.load(audio_path, sr=SR)  # whole signal
    n_sample = src.shape[0]
    n_sample_fit = int(DURA * SR)

    if n_sample < n_sample_fit:  # if too short, zero-pad to the target length
        src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,))))
    elif n_sample > n_sample_fit:  # if too long, keep the centered excerpt
        src = src[(n_sample - n_sample_fit) // 2:(n_sample + n_sample_fit) //
                  2]
    logam = librosa.core.amplitude_to_db
    melgram = librosa.feature.melspectrogram
    ret = logam(
        melgram(y=src, sr=SR, hop_length=HOP_LEN, n_fft=N_FFT, n_mels=N_MELS))
    ret = ret[np.newaxis, np.newaxis, :]
    if save_npy:
        save_path = save_path + filename[:-4] + '.npy'
        np.save(save_path, ret)
    return ret


def get_data(features_data, labels_data):
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i, "feature": feature, "label": label}
        data_list.append(data_json)
    return data_list


def convert(s):
    if s.isdigit():
        return int(s)
    return s


def GetLabel(info_path, info_name):
    """
    separate dataset into training set and validation set

    Args:
        info_path (str): path to the information file.
        info_name (str): name of the information file.
    """
    T = []
    with open(info_path + '/' + info_name, 'rb') as info:
        data = info.readline()
        while data:
            T.append([
                convert(i[1:-1])
                for i in data.strip().decode('utf-8').split("\t")
            ])
            data = info.readline()

    annotation = pd.DataFrame(T[1:], columns=T[0])
    count = []
    for i in annotation.columns[1:-2]:
        count.append([annotation[i].sum() / len(annotation), i])
    # keep the 50 most frequent tags
    count = sorted(count)
    full_label = []
    for i in count[-50:]:
        full_label.append(i[1])
    out = []
    for i in T[1:]:
        index = [k for k, x in enumerate(i) if x == 1]
        label = [T[0][k] for k in index]
        L = [str(0) for k in range(50)]
        L.append(i[-1])
        for j in label:
            if j in full_label:
                ind = full_label.index(j)
                L[ind] = '1'
        out.append(L)
    out = np.array(out)

    Train = []
    Val = []

    # random split: roughly 80% training, 20% validation
    for i in out:
        if np.random.rand() > 0.2:
            Train.append(i)
        else:
            Val.append(i)
    np.savetxt("{}/music_tagging_train_tmp.csv".format(info_path),
               np.array(Train),
               fmt='%s',
               delimiter=',')
    np.savetxt("{}/music_tagging_val_tmp.csv".format(info_path),
               np.array(Val),
               fmt='%s',
               delimiter=',')


def generator_md(info_name, file_path, num_classes):
    """
    generate numpy arrays from features of all audio clips

    Args:
        info_name (str): path to the information file.
        file_path (str): path to the npy files.

    Returns:
        2 numpy arrays.
    """
    df = pd.read_csv(info_name, header=None)
    df.columns = [str(i) for i in range(num_classes)] + ["mp3_path"]
    data = []
    label = []
    for i in range(len(df)):
        try:
            data.append(
                np.load(file_path + df.mp3_path.values[i][:-4] +
                        '.npy').reshape(1, 96, 1366))
            label.append(np.array(df[df.columns[:-1]][i:i + 1])[0])
        except FileNotFoundError:
            print("Exception occurred in generator_md.")
    return np.array(data), np.array(label, dtype=np.int32)


def convert_to_mindrecord(info_name, file_path, store_path, mr_name,
                          num_classes):
    """ convert dataset to mindrecord """
    num_shard = 4
    data, label = generator_md(info_name, file_path, num_classes)
    schema_json = {
        "id": {
            "type": "int32"
        },
        "feature": {
            "type": "float32",
            "shape": [1, 96, 1366]
        },
        "label": {
            "type": "int32",
            "shape": [num_classes]
        }
    }

    writer = FileWriter(
        os.path.join(store_path, '{}.mindrecord'.format(mr_name)), num_shard)
    datax = get_data(data, label)
    writer.add_schema(schema_json, "music_tagger_schema")
    writer.add_index(["id"])
    writer.write_raw_data(datax)
    writer.commit()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='get feature')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()

    if cfg.get_npy:
        GetLabel(cfg.info_path, cfg.info_name)
        dirname = os.listdir(cfg.audio_path)
        for d in dirname:
            file_name = os.listdir("{}/{}".format(cfg.audio_path, d))
            if not os.path.isdir("{}/{}".format(cfg.npy_path, d)):
                os.mkdir("{}/{}".format(cfg.npy_path, d))
            for f in file_name:
                compute_melgram("{}/{}/{}".format(cfg.audio_path, d, f),
                                "{}/{}/".format(cfg.npy_path, d), f)

    if cfg.get_mindrecord:
        if args.device_id is not None:
            context.set_context(device_target='Ascend',
                                mode=context.GRAPH_MODE,
                                device_id=args.device_id)
        else:
            context.set_context(device_target='Ascend',
                                mode=context.GRAPH_MODE,
                                device_id=cfg.device_id)
        for cmn in cfg.mr_name:  # fixed typo: was cfg.mr_nam
            if cmn in ['train', 'val']:
                convert_to_mindrecord('music_tagging_{}_tmp.csv'.format(cmn),
                                      cfg.npy_path, cfg.mr_path, cmn,
                                      cfg.num_classes)

@ -0,0 +1,50 @@
choral
female voice
metal
country
weird
no voice
cello
harp
beats
female vocal
male voice
dance
new age
voice
choir
classic
man
solo
sitar
soft
no vocal
pop
male vocal
woman
flute
quiet
loud
harpsichord
no vocals
vocals
singing
male
opera
indian
female
synth
vocal
violin
beat
ambient
piano
fast
rock
electronic
drums
strings
techno
slow
classical
guitar

@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############train models#################
python train.py
'''
import argparse
from mindspore import context, nn
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.dataset import create_dataset
from src.musictagger import MusicTaggerCNN
from src.loss import BCELoss
from src.config import music_cfg as cfg


def train(model, dataset_direct, filename, columns_list, num_consumer=4,
          batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
          prefix="model", directory='./'):
    """
    train network
    """
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=prefix,
                                 directory=directory,
                                 config=config_ck)
    data_train = create_dataset(dataset_direct, filename, batch, columns_list,
                                num_consumer)

    model.train(epoch,
                data_train,
                callbacks=[
                    ckpoint_cb,
                    LossMonitor(per_print_times=181),
                    TimeMonitor()
                ],
                dataset_sink_mode=True)


if __name__ == "__main__":
    set_seed(1)
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)

    args = parser.parse_args()

    if args.device_id is not None:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)

    context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)

    if cfg.pre_trained:
        # resume from the checkpoint named in src/config.py
        param_dict = load_checkpoint(cfg.checkpoint_path + '/' +
                                     cfg.model_name)
        load_param_into_net(network, param_dict)

    net_loss = BCELoss()

    network.set_train(True)
    net_opt = nn.Adam(params=network.trainable_params(),
                      learning_rate=cfg.lr,
                      loss_scale=cfg.loss_scale)

    # fixed loss scaling keeps FP16 gradients from underflowing
    loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale,
                                               drop_overflow_update=False)
    net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)

    train(model=net_model,
          dataset_direct=cfg.data_dir,
          filename=cfg.train_filename,
          columns_list=['feature', 'label'],
          num_consumer=cfg.num_consumer,
          batch=cfg.batch_size,
          epoch=cfg.epoch_size,
          save_checkpoint_steps=cfg.save_step,
          keep_checkpoint_max=cfg.keep_checkpoint_max,
          prefix=cfg.prefix,
          directory=cfg.checkpoint_path + "_{}".format(cfg.device_id))
    print("train success")