!7456 Add dscnn network to modelzoo

Merge pull request !7456 from zhanghuiyao/dscnn_branch
This commit is contained in:
mindspore-ci-bot 2020-10-20 16:15:10 +08:00 committed by Gitee
commit 9d8707d101
18 changed files with 2291 additions and 0 deletions

View File

@ -0,0 +1,314 @@
# Contents
- [DS-CNN Description](#DS-CNN-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [How to use](#how-to-use)
- [Inference](#inference)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DS-CNN Description](#contents)
DS-CNN, depthwise separable convolutional neural network, was first used in Keyword Spotting in 2017. KWS application has highly constrained power budget and typically runs on tiny microcontrollers with limited memory and compute capability. depthwise separable convolutions are more efficient both in number of parameters and operations, which makes deeper and wider architecture possible even in the resource-constrained microcontroller devices.
[Paper](https://arxiv.org/abs/1711.07128): Zhang, Yundong, Naveen Suda, Liangzhen Lai, and Vikas Chandra. "Hello edge: Keyword spotting on microcontrollers." arXiv preprint arXiv:1711.07128 (2017).
# [Model Architecture](#contents)
The overall network architecture of DS-CNN is show below:
[Link](https://arxiv.org/abs/1711.07128)
# [Dataset](#contents)
Dataset used: [Speech commands dataset version 1](<https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html>)
- Dataset size2.02GiB, 65,000 one-second long utterances of 30 short words, by thousands of different people
- Train 80%
- Val 10%
- Test 10%
- Data formatWAVE format file, with the sample data encoded as linear 16-bit single-channel PCM values, at a 16 KHz rate
- NoteData will be processed in download_process_data.py
Dataset used: [Speech commands dataset version 2](<https://arxiv.org/abs/1804.03209>)
- Dataset size 8.17 GiB. 105,829 a one-second (or less) long utterances of 35 words by 2,618 speakers
- Train 80%
- Val 10%
- Test 10%
- Data formatWAVE format file, with the sample data encoded as linear 16-bit single-channel PCM values, at a 16 KHz rate
- NoteData will be processed in download_process_data.py
# [Environment Requirements](#contents)
- HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- Third party open source packageif have
- numpy
- soundfile
- python_speech_features
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
First set the config for data, train, eval in src/config.py
- download and process dataset
```
python src/download_process_data.py
```
- running on Ascend
```python
# run training example
python train.py
# run evaluation example
# if you want to eval a specific model, you should specify model_dir to the ckpt path:
python eval.py --model_dir your_ckpt_path
# if you want to eval all the model you saved, you should specify model_dir to the folder where the models are saved.
python eval.py --model_dir your_models_folder_path
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
├── MODELZOO_DSCNN_MS_MTI
├── README.md // descriptions about ds-cnn
├── scripts
│ ├──run_download_process_data.sh // shell script for download dataset and prepare feature and label
│ ├──run_train_ascend.sh // shell script for train on ascend
│ ├──run_eval_ascend.sh // shell script for evaluation on ascend
├── src
│ ├──callback.py // callbacks
│ ├──config.py // parameter configuration of data, train and eval
│ ├──dataset.py // creating dataset
│ ├──download_process_data.py // download and prepare train, val, test data
│ ├──ds_cnn.py // dscnn architecture
│ ├──log.py // logging class
│ ├──loss.py // loss function
│ ├──lr_scheduler.py // lr_scheduler
│ ├──models.py // load ckpt
│ ├──utils.py // some function for prepare data
├── train.py // training script
├── eval.py // evaluation script
├── export.py // export checkpoint files into air/geir
├── requirements.txt // Third party open source package
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py.
- config for dataset for Speech commands dataset version 1
```python
'data_url': 'http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz'
# Location of speech training data archive on the web
'data_dir': 'data' # Where to download the dataset
'feat_dir': 'feat' # Where to save the feature and label of audios
'background_volume': 0.1 # How loud the background noise should be, between 0 and 1.
'background_frequency': 0.8 # How many of the training samples have background noise mixed in.
'silence_percentage': 10.0 # How much of the training data should be silence.
'unknown_percentage': 10.0 # How much of the training data should be unknown words
'time_shift_ms': 100.0 # Range to randomly shift the training audio by in time
'testing_percentage': 10 # What percentage of wavs to use as a test set
'validation_percentage': 10 # What percentage of wavs to use as a validation set
'wanted_words': 'yes,no,up,down,left,right,on,off,stop,go'
# Words to use (others will be added to an unknown label)
'sample_rate': 16000 # Expected sample rate of the wavs
'device_id': 1000 # device ID used to train or evaluate the dataset.
'clip_duration_ms': 10 # Expected duration in milliseconds of the wavs
'window_size_ms': 40.0 # How long each spectrogram timeslice is
'window_stride_ms': 20.0 # How long each spectrogram timeslice is
'dct_coefficient_count': 20 # How many bins to use for the MFCC fingerprint
```
- config for DS-CNN and train parameters of Speech commands dataset version 1
```python
'model_size_info': [6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1]
# Model dimensions - different for various models
'drop': 0.9 # dropout
'pretrained': '' # model_path, local pretrained model to load
'use_graph_mode': 1 # use graph mode or feed mode
'val_interval': 1 # validate interval
'per_batch_size': 100 # batch size for per gpu
'lr_scheduler': 'multistep' # lr-scheduler, option type: multistep, cosine_annealing
'lr': 0.1 # learning rate of the training
'lr_epochs': '20,40,60,80' # epoch of lr changing
'lr_gamma': 0.1 # decrease lr by a factor of exponential lr_scheduler
'eta_min': 0 # eta_min in cosine_annealing scheduler
'T_max': 80 # T-max in cosine_annealing scheduler
'max_epoch': 80 # max epoch num to train the model
'warmup_epochs': 0 # warmup epoch
'weight_decay': 0.001 # weight decay
'momentum': 0.98 # weight decay
'log_interval': 100 # logging interval
'ckpt_path': 'train_outputs' # the location where checkpoint and log will be saved
'ckpt_interval': 100 # save ckpt_interval
```
- config for DS-CNN and evaluation parameters of Speech commands dataset version 1
```python
'feat_dir': 'feat' # Where to save the feature of audios
'model_dir': '' # which folder the models are saved in or specific path of one model
'wanted_words': 'yes,no,up,down,left,right,on,off,stop,go'
# Words to use (others will be added to an unknown label)
'sample_rate': 16000 # Expected sample rate of the wavs
'device_id': 1000 # device ID used to train or evaluate the dataset.
'clip_duration_ms': 10 # Expected duration in milliseconds of the wavs
'window_size_ms': 40.0 # How long each spectrogram timeslice is
'window_stride_ms': 20.0 # How long each spectrogram timeslice is
'dct_coefficient_count': 20 # How many bins to use for the MFCC fingerprint
'model_size_info': [6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1]
# Model dimensions - different for various models
'pre_batch_size': 100 # batch size for eval
'drop': 0.9 # dropout in train
'log_path': 'eval_outputs' # path to save eval log
```
## [Training Process](#contents)
### Training
- running on Ascend
for shell script:
```python
# sh srcipts/run_train_ascend.sh [device_id]
sh srcipts/run_train_ascend.sh 0
```
for python script:
```python
# python train.py --device_id [device_id]
python train.py --device_id 0
```
you can see the args and loss, acc info on your screen, you also can view the results in folder train_outputs
```python
epoch[1], iter[443], loss:0.73811543, mean_wps:12102.26 wavs/sec
Eval: top1_cor:737, top5_cor:1699, tot:3000, acc@1=24.57%, acc@5=56.63%
epoch[2], iter[665], loss:0.381568, mean_wps:12107.45 wavs/sec
Eval: top1_cor:1355, top5_cor:2615, tot:3000, acc@1=45.17%, acc@5=87.17%
...
...
Best epoch:41 acc:93.73%
```
The checkpoints and log will be saved in the train_outputs.
## [Evaluation Process](#contents)
### Evaluation
- evaluation on Speech commands dataset version 1 when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation. Please set model_dir in config.py or pass model_dir in your command line.
for shell scripts:
```python
# sh scripts/run_eval_ascend.sh device_id model_dir
sh scripts/run_eval_ascend.sh 0 train_outputs/*/*.ckpt
or
sh scripts/run_eval_ascend.sh 0 train_outputs/*/
```
for python scripts:
```python
# python eval.py --device_id device_id --model_dir model_dir
python eval.py --device_id 0 --model_dir train_outputs/*/*.ckpt
or
python eval.py --device_id 0 --model_dir train_outputs/*
```
You can view the results on the screen or from logs in eval_outputs folder. The accuracy of the test dataset will be as follows:
```python
Eval: top1_cor:2805, top5_cor:2963, tot:3000, acc@1=93.50%, acc@5=98.77%
Best model:train_outputs/*/epoch41-1_223.ckpt acc:93.50%
```
# [Model Description](#contents)
## [Performance](#contents)
### Train Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | DS-CNN |
| Resource | Ascend 910 CPU 2.60GHz56coresMemory314G |
| uploaded Date | 27/09/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | Speech commands dataset version 1 |
| Training Parameters | epoch=80, batch_size = 100, lr=0.1 |
| Optimizer | Momentum |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.0019 |
| Speed | 2s/epoch |
| Total time | 4 mins |
| Parameters (K) | 500K |
| Checkpoint for Fine tuning | 3.3M (.ckpt file) |
| Script | [Link]() | [Link]() |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | DS-CNN |
| Resource | Ascend 910 |
| Uploaded Date | 09/27/2020 |
| MindSpore Version | 1.0.0 |
| Dataset |Speech commands dataset version 1 |
| Training Parameters | src/config.py |
| outputs | probability |
| Accuracy | 93.96% |
| Total time | 3min |
| Params (K) | 500K |
|Checkpoint for Fine tuning (M) | 3.3M |
# [Description of Random Situation](#contents)
In download_process_data.py, we set the seed for split train, val, test set.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,112 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN eval."""
import os
import datetime
import glob
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor, Model
from mindspore.common import dtype as mstype
from src.config import eval_config
from src.log import get_logger
from src.dataset import audio_dataset
from src.ds_cnn import DSCNN
from src.models import load_ckpt
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def val(args, model, test_de):
'''Eval.'''
eval_dataloader = test_de.create_tuple_iterator()
img_tot = 0
top1_correct = 0
top5_correct = 0
for data, gt_classes in eval_dataloader:
output = model.predict(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
gt_classes = gt_classes.asnumpy()
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += output.shape[0]
results = [[top1_correct], [top5_correct], [img_tot]]
results = np.array(results)
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
if acc1 > args.best_acc:
args.best_acc = acc1
args.best_index = args.index
args.logger.info('Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%' \
.format(top1_correct, top5_correct, img_tot, acc1, acc5))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
args, model_settings = eval_config(parser)
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", device_id=args.device_id)
# Logger
args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir)
# show args
args.logger.save_args(args)
# find model path
if os.path.isdir(args.model_dir):
models = list(glob.glob(os.path.join(args.model_dir, '*.ckpt')))
print(models)
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[0].split('epoch')[-1])
args.models = sorted(models, key=f)
else:
args.models = [args.model_dir]
args.best_acc = 0
args.index = 0
args.best_index = 0
for model_path in args.models:
test_de = audio_dataset(args.feat_dir, 'testing', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
network = DSCNN(model_settings, args.model_size_info)
load_ckpt(network, model_path, False)
network.set_train(False)
model = Model(network)
args.logger.info('load model {} success'.format(model_path))
val(args, model, test_de)
args.index += 1
args.logger.info('Best model:{} acc:{:.2f}%'.format(args.models[args.best_index], args.best_acc))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN export."""
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export
from src.config import eval_config
from src.ds_cnn import DSCNN
from src.models import load_ckpt
parser = argparse.ArgumentParser()
args, model_settings = eval_config(parser)
network = DSCNN(model_settings, args.model_size_info)
load_ckpt(network, args.model_dir, False)
x = np.random.uniform(0.0, 1.0, size=[1, 1, model_settings['spectrogram_length'],
model_settings['dct_coefficient_count']]).astype(np.float32)
export(network, Tensor(x), file_name=args.model_dir.replace('.ckpt', '.air'), file_format='AIR')

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python src/download_process_data.py

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python eval.py --device_id $1 --model_dir $2

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python train.py --device_id $1

View File

@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Callback."""
import time
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
class ProgressMonitor(Callback):
'''Progress Monitor.'''
def __init__(self, args):
super(ProgressMonitor, self).__init__()
self.args = args
self.epoch_start_time = 0
self.step_start_time = 0
self.globe_step_cnt = 0
self.local_step_cnt = 0
self.ckpt_history = []
def begin(self, run_context):
if not self.args.epoch_cnt:
self.args.logger.info('start network train...')
if run_context is None:
pass
def step_begin(self, run_context):
if self.local_step_cnt == 0:
self.step_start_time = time.time()
if run_context is None:
pass
def step_end(self, run_context):
'''Callback when step end.'''
if self.local_step_cnt % self.args.log_interval == 0 and self.local_step_cnt > 0:
cb_params = run_context.original_args()
time_used = time.time() - self.step_start_time
fps_mean = self.args.per_batch_size * self.args.log_interval / time_used
self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(self.args.epoch_cnt,
self.globe_step_cnt +
self.local_step_cnt,
cb_params.net_outputs,
fps_mean))
self.step_start_time = time.time()
self.local_step_cnt += 1
def epoch_begin(self, run_context):
self.epoch_start_time = time.time()
if run_context is None:
pass
def epoch_end(self, run_context):
'''Callback when epoch end.'''
cb_params = run_context.original_args()
self.globe_step_cnt = self.args.steps_per_epoch * (self.args.epoch_cnt + 1) - 1
time_used = time.time() - self.epoch_start_time
fps_mean = self.args.per_batch_size * self.args.steps_per_epoch / time_used
self.args.logger.info(
'epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(self.args.epoch_cnt, self.globe_step_cnt,
cb_params.net_outputs, fps_mean))
self.args.epoch_cnt += 1
self.local_step_cnt = 0
def end(self, run_context):
pass
def callback_func(args, cb, prefix):
callbacks = [cb]
if args.rank_save_ckpt_flag:
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.outputs_dir, prefix=prefix)
callbacks.append(ckpt_cb)
return callbacks

View File

@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Config setting, will be used in train.py and eval.py"""
from src.utils import prepare_words_list
def data_config(parser):
'''config for data.'''
parser.add_argument('--data_url', type=str,
default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
help='Location of speech training data archive on the web.')
parser.add_argument('--data_dir', type=str, default='data',
help='Where to download the dataset.')
parser.add_argument('--feat_dir', type=str, default='feat',
help='Where to save the feature of audios')
parser.add_argument('--background_volume', type=float, default=0.1,
help='How loud the background noise should be, between 0 and 1.')
parser.add_argument('--background_frequency', type=float, default=0.8,
help='How many of the training samples have background noise mixed in.')
parser.add_argument('--silence_percentage', type=float, default=10.0,
help='How much of the training data should be silence.')
parser.add_argument('--unknown_percentage', type=float, default=10.0,
help='How much of the training data should be unknown words.')
parser.add_argument('--time_shift_ms', type=float, default=100.0,
help='Range to randomly shift the training audio by in time.')
parser.add_argument('--testing_percentage', type=int, default=10,
help='What percentage of wavs to use as a test set.')
parser.add_argument('--validation_percentage', type=int, default=10,
help='What percentage of wavs to use as a validation set.')
parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)')
parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
parser.add_argument('--clip_duration_ms', type=int, default=1000,
help='Expected duration in milliseconds of the wavs')
parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
parser.add_argument('--window_stride_ms', type=float, default=20.0, help='How long each spectrogram timeslice is')
parser.add_argument('--dct_coefficient_count', type=int, default=20,
help='How many bins to use for the MFCC fingerprint')
def train_config(parser):
'''config for train.'''
data_config(parser)
# network related
parser.add_argument('--model_size_info', type=int, nargs="+",
default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
1, 276, 3, 3, 1, 1],
help='Model dimensions - different for various models')
parser.add_argument('--drop', type=float, default=0.9, help='dropout')
parser.add_argument('--pretrained', type=str, default='', help='model_path, local pretrained model to load')
# training related
parser.add_argument('--use_graph_mode', default=1, type=int, help='use graph mode or feed mode')
parser.add_argument('--val_interval', type=int, default=1, help='validate interval')
# dataset related
parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='multistep', type=str,
help='lr-scheduler, option type: multistep, cosine_annealing')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate of the training')
parser.add_argument('--lr_epochs', type=str, default='20,40,60,80', help='epoch of lr changing')
parser.add_argument('--lr_gamma', type=float, default=0.1,
help='decrease lr by a factor of exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--T_max', type=int, default=80, help='T-max in cosine_annealing scheduler')
parser.add_argument('--max_epoch', type=int, default=80, help='max epoch num to train the model')
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.98, help='momentum')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='train_outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=100, help='save ckpt_interval')
flags, _ = parser.parse_known_args()
flags.dataset_sink_mode = bool(flags.use_graph_mode)
flags.lr_epochs = list(map(int, flags.lr_epochs.split(',')))
model_settings = prepare_model_settings(
len(prepare_words_list(flags.wanted_words.split(','))),
flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
flags.window_stride_ms, flags.dct_coefficient_count)
model_settings['dropout1'] = flags.drop
return flags, model_settings
def eval_config(parser):
'''config for eval.'''
parser.add_argument('--feat_dir', type=str, default='feat',
help='Where to save the feature of audios')
parser.add_argument('--model_dir', type=str,
default='outputs',
help='which folder the models are saved in or specific path of one model')
parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)')
parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
parser.add_argument('--clip_duration_ms', type=int, default=1000,
help='Expected duration in milliseconds of the wavs')
parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
parser.add_argument('--window_stride_ms', type=float, default=20.0, help='How long each spectrogram timeslice is')
parser.add_argument('--dct_coefficient_count', type=int, default=20,
help='How many bins to use for the MFCC fingerprint')
parser.add_argument('--model_size_info', type=int, nargs="+",
default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
1, 276, 3, 3, 1, 1],
help='Model dimensions - different for various models')
parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')
parser.add_argument('--drop', type=float, default=0.9, help='dropout')
# logging related
parser.add_argument('--log_path', type=str, default='eval_outputs/', help='path to save eval log')
flags, _ = parser.parse_known_args()
model_settings = prepare_model_settings(
len(prepare_words_list(flags.wanted_words.split(','))),
flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
flags.window_stride_ms, flags.dct_coefficient_count)
model_settings['dropout1'] = flags.drop
return flags, model_settings
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
window_size_ms, window_stride_ms,
dct_coefficient_count):
'''Prepare model setting.'''
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
window_stride_samples = int(sample_rate * window_stride_ms / 1000)
length_minus_window = (desired_samples - window_size_samples)
if length_minus_window < 0:
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
fingerprint_size = dct_coefficient_count * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
'dct_coefficient_count': dct_coefficient_count,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
}

View File

@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN dataset."""
import os
import numpy as np
import mindspore.dataset as de
class NpyDataset():
'''Dataset from numpy.'''
def __init__(self, data_dir, data_type, h, w):
super(NpyDataset, self).__init__()
self.data = np.load(os.path.join(data_dir, '{}_data.npy'.format(data_type)))
self.data = np.reshape(self.data, (-1, 1, h, w))
self.label = np.load(os.path.join(data_dir, '{}_label.npy'.format(data_type)))
def __len__(self):
return self.data.shape[0]
def __getitem__(self, item):
data = self.data[item]
label = self.label[item]
# return data, label
return data.astype(np.float32), label.astype(np.int32)
def audio_dataset(data_dir, data_type, h, w, batch_size):
if 'testing' in data_dir:
shuffle = False
else:
shuffle = True
dataset = NpyDataset(data_dir, data_type, h, w)
de_dataset = de.GeneratorDataset(dataset, ["feats", "labels"], shuffle=shuffle)
de_dataset = de_dataset.batch(batch_size, drop_remainder=False)
return de_dataset

View File

@ -0,0 +1,276 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Download process data."""
import hashlib
import math
import os.path
import random
import re
import sys
import tarfile
from glob import glob
import logging
import argparse
from six.moves import urllib
import numpy as np
import soundfile as sf
from python_speech_features import mfcc
from tqdm import tqdm
from src.config import train_config, prepare_model_settings
from src.utils import prepare_words_list
FLAGS = None
MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1 # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185
K = 0
def which_set(filename, validation_percentage, testing_percentage):
'''Which set.'''
base_name = os.path.basename(filename)
hash_name = re.sub(r'_nohash_.*$', '', base_name)
hash_name_hashed = hashlib.sha1(bytes(hash_name, 'utf-8')).hexdigest()
percentage_hash = ((int(hash_name_hashed, 16) %
(MAX_NUM_WAVS_PER_CLASS + 1)) *
(100.0 / MAX_NUM_WAVS_PER_CLASS))
if percentage_hash < validation_percentage:
result = 'validation'
elif percentage_hash < (testing_percentage + validation_percentage):
result = 'testing'
else:
result = 'training'
return result
class AudioProcessor():
"""Handles loading, partitioning, and preparing audio training data."""
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
wanted_words, validation_percentage, testing_percentage,
model_settings):
self.data_dir = data_dir
self.maybe_download_and_extract_dataset(data_url, data_dir)
self.prepare_data_index(silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage)
self.prepare_background_data()
self.prepare_data(model_settings)
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
'''Maybe download and extract dataset.'''
if not data_url:
return
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = data_url.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write(
'\r>> Downloading %s %.1f%%' %
(filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
try:
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
except:
logging.error('Failed to download URL: %s to folder: %s', data_url,
filepath)
logging.error('Please make sure you have enough free space and'
' an internet connection')
raise
print()
statinfo = os.stat(filepath)
logging.info('Successfully downloaded %s (%d bytes)', filename,
statinfo.st_size)
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def prepare_data_index(self, silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage):
'''Prepare data index.'''
# Make sure the shuffling and picking of unknowns is deterministic.
random.seed(RANDOM_SEED)
wanted_words_index = {}
for index, wanted_word in enumerate(wanted_words):
wanted_words_index[wanted_word] = index + 2
self.data_index = {'validation': [], 'testing': [], 'training': []}
unknown_index = {'validation': [], 'testing': [], 'training': []}
all_words = {}
# Look through all the subfolders to find audio samples
search_path = os.path.join(self.data_dir, '*', '*.wav')
for wav_path in glob(search_path):
_, word = os.path.split(os.path.dirname(wav_path))
word = word.lower()
# Treat the '_background_noise_' folder as a special case, since we expect
# it to contain long audio samples we mix in to improve training.
if word == BACKGROUND_NOISE_DIR_NAME:
continue
all_words[word] = True
set_index = which_set(wav_path, validation_percentage, testing_percentage)
# If it's a known class, store its detail, otherwise add it to the list
# we'll use to train the unknown label.
if word in wanted_words_index:
self.data_index[set_index].append({'label': word, 'file': wav_path})
else:
unknown_index[set_index].append({'label': word, 'file': wav_path})
if not all_words:
raise Exception('No .wavs found at ' + search_path)
for index, wanted_word in enumerate(wanted_words):
if wanted_word not in all_words:
raise Exception('Expected to find ' + wanted_word +
' in labels but only found ' +
', '.join(all_words.keys()))
# We need an arbitrary file to load as the input for the silence samples.
# It's multiplied by zero later, so the content doesn't matter.
silence_wav_path = self.data_index['training'][0]['file']
for set_index in ['validation', 'testing', 'training']:
set_size = len(self.data_index[set_index])
silence_size = int(math.ceil(set_size * silence_percentage / 100))
for _ in range(silence_size):
self.data_index[set_index].append({
'label': SILENCE_LABEL,
'file': silence_wav_path
})
# Pick some unknowns to add to each partition of the data set.
random.shuffle(unknown_index[set_index])
unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
# Make sure the ordering is random.
for set_index in ['validation', 'testing', 'training']:
random.shuffle(self.data_index[set_index])
# Prepare the rest of the result data structure.
self.words_list = prepare_words_list(wanted_words)
self.word_to_index = {}
for word in all_words:
if word in wanted_words_index:
self.word_to_index[word] = wanted_words_index[word]
else:
self.word_to_index[word] = UNKNOWN_WORD_INDEX
self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX
def prepare_background_data(self):
'''Prepare background data.'''
self.background_data = []
background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
if not os.path.exists(background_dir):
return self.background_data
search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
'*.wav')
for wav_path in glob(search_path):
wav_data, _ = sf.read(wav_path)
self.background_data.append(wav_data)
if not self.background_data:
raise Exception('No background wav files were found in ' + search_path)
return None
def prepare_single_sample(self, wav_filename, foreground_volume, time_shift_padding, time_shift_offset,
desired_samples, background_data, background_volume):
'''Prepare single sample.'''
wav_data, _ = sf.read(wav_filename)
if len(wav_data) < desired_samples:
wav_data = np.pad(wav_data, [0, desired_samples - len(wav_data)], 'constant')
scaled_foreground = wav_data * foreground_volume
padded_foreground = np.pad(scaled_foreground, time_shift_padding, 'constant')
sliced_foreground = padded_foreground[time_shift_offset: time_shift_offset + desired_samples]
background_add = background_data[0] * background_volume + sliced_foreground
background_clamp = np.clip(background_add, -1.0, 1.0)
# feature = mfcc(background_clamp, samplerate=FLAGS.sample_rate, winlen=0.03, winstep=0.01, numcep=40, nfilt=40).flatten()
feature = mfcc(background_clamp, samplerate=FLAGS.sample_rate, winlen=FLAGS.window_size_ms / 1000,
winstep=FLAGS.window_stride_ms / 1000,
numcep=FLAGS.dct_coefficient_count, nfilt=40, nfft=1024, lowfreq=20, highfreq=7000).flatten()
return feature
def prepare_data(self, model_settings):
'''Prepare data.'''
# Pick one of the partitions to choose samples from.
time_shift = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
background_frequency = FLAGS.background_frequency
background_volume_range = FLAGS.background_volume
desired_samples = model_settings['desired_samples']
if not os.path.exists(FLAGS.feat_dir):
os.makedirs(FLAGS.feat_dir, exist_ok=True)
for mode in ['training', 'validation', 'testing']:
candidates = self.data_index[mode]
sample_count = len(candidates)
# Data and labels will be populated and returned.
data = np.zeros((sample_count, model_settings['fingerprint_size']))
labels = np.zeros(sample_count)
use_background = self.background_data and (mode == 'training')
for i in tqdm(range(sample_count)):
# Pick which audio sample to use.
sample_index = i
sample = candidates[sample_index]
# If we're time shifting, set up the offset for this sample.
if time_shift > 0:
time_shift_amount = np.random.randint(-time_shift, time_shift)
else:
time_shift_amount = 0
if time_shift_amount > 0:
time_shift_padding = [[time_shift_amount, 0]]
time_shift_offset = 0
else:
time_shift_padding = [[0, -time_shift_amount]]
time_shift_offset = -time_shift_amount
if use_background:
background_index = np.random.randint(len(self.background_data))
background_samples = self.background_data[background_index]
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_clipped = background_samples[background_offset:(
background_offset + desired_samples)]
background_reshaped = background_clipped.reshape([desired_samples, 1])
if np.random.uniform(0, 1) < background_frequency:
background_volume = np.random.uniform(0, background_volume_range)
else:
background_volume = 0
else:
background_reshaped = np.zeros([desired_samples, 1])
background_volume = 0
if sample['label'] == SILENCE_LABEL:
foreground_volume = 0
else:
foreground_volume = 1
data[i, :] = self.prepare_single_sample(sample['file'], foreground_volume, time_shift_padding,
time_shift_offset, desired_samples,
background_reshaped, background_volume)
label_index = self.word_to_index[sample['label']]
labels[i] = label_index
np.save(os.path.join(FLAGS.feat_dir, '{}_data.npy'.format(mode)), data)
np.save(os.path.join(FLAGS.feat_dir, '{}_label.npy'.format(mode)), labels)
if __name__ == '__main__':
print('start download_process')
parser = argparse.ArgumentParser()
train_config(parser)
FLAGS, unparsed = parser.parse_known_args()
model_settings_1 = prepare_model_settings(
len(prepare_words_list(FLAGS.wanted_words.split(','))),
FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
audio_processor = AudioProcessor(
FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
FLAGS.unknown_percentage,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
FLAGS.testing_percentage, model_settings_1)

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN network."""
import math
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore import Parameter
class DepthWiseConv(nn.Cell):
'''Build DepthWise conv.'''
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
super(DepthWiseConv, self).__init__()
self.has_bias = has_bias
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
stride=stride, pad_mode=pad_mode, pad=pad)
self.bias_add = P.BiasAdd()
weight_shape = [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]]
self.weight = Parameter(initializer('ones', weight_shape), name='weight')
if has_bias:
bias_shape = [channel_multiplier * in_planes]
self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
else:
self.bias = None
def construct(self, x):
output = self.depthwise_conv(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output
class DSCNN(nn.Cell):
'''Build DSCNN network.'''
def __init__(self, model_settings, model_size_info):
super(DSCNN, self).__init__()
# N C H W
label_count = model_settings['label_count']
input_frequency_size = model_settings['dct_coefficient_count']
input_time_size = model_settings['spectrogram_length']
t_dim = input_time_size
f_dim = input_frequency_size
num_layers = model_size_info[0]
conv_feat = [None] * num_layers
conv_kt = [None] * num_layers
conv_kf = [None] * num_layers
conv_st = [None] * num_layers
conv_sf = [None] * num_layers
i = 1
for layer_no in range(0, num_layers):
conv_feat[layer_no] = model_size_info[i]
i += 1
conv_kt[layer_no] = model_size_info[i]
i += 1
conv_kf[layer_no] = model_size_info[i]
i += 1
conv_st[layer_no] = model_size_info[i]
i += 1
conv_sf[layer_no] = model_size_info[i]
i += 1
seq_cell = []
in_channel = 1
for layer_no in range(0, num_layers):
if layer_no == 0:
seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no],
kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
stride=(conv_st[layer_no], conv_sf[layer_no]),
pad_mode="same", padding=0, has_bias=False))
seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98))
in_channel = conv_feat[layer_no]
else:
seq_cell.append(DepthWiseConv(in_planes=in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', pad=0))
seq_cell.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98))
seq_cell.append(nn.ReLU())
seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no], kernel_size=(1, 1),
pad_mode="same"))
seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98))
seq_cell.append(nn.ReLU())
in_channel = conv_feat[layer_no]
t_dim = math.ceil(t_dim / float(conv_st[layer_no]))
f_dim = math.ceil(f_dim / float(conv_sf[layer_no]))
seq_cell.append(nn.AvgPool2d(kernel_size=(t_dim, f_dim))) # to fix ?
seq_cell.append(nn.Flatten())
seq_cell.append(nn.Dropout(model_settings['dropout1']))
seq_cell.append(nn.Dense(in_channel, label_count))
self.model = nn.SequentialCell(seq_cell)
def construct(self, x):
x = self.model(x)
return x

View File

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Logger."""
import os
import sys
import logging
from datetime import datetime
logger_name_1 = 'ds-cnn'
class LOGGER(logging.Logger):
'''Build logger.'''
def __init__(self, logger_name):
super(LOGGER, self).__init__(logger_name)
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir):
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
line_width = 2
important_msg = '\n'
important_msg += ('*' * 70 + '\n') * line_width
important_msg += ('*' * line_width + '\n') * 2
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
important_msg += ('*' * line_width + '\n') * 2
important_msg += ('*' * 70 + '\n') * line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path):
logger = LOGGER(logger_name_1)
logger.setup_logging_file(path)
return logger

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN loss."""
import mindspore.nn as nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
class CrossEntropy(_Loss):
'''Build CrossEntropy Loss.'''
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes -1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label,
F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

View File

@ -0,0 +1,723 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Learning rate scheduler."""
import math
from collections import Counter
import numpy as np
__all__ = ["LambdaLR", "MultiplicativeLR", "StepLR", "MultiStepLR", "ExponentialLR", "CosineAnnealingLR", "CyclicLR",
"CosineAnnealingWarmRestarts", "OneCycleLR", "POLYLR"]
class _WarmUp():
"""
Basic class for warm up
"""
def __init__(self, warmup_init_lr):
self.warmup_init_lr = warmup_init_lr
def get_lr(self, current_step=None):
# Get learning rate during warmup
raise NotImplementedError
class _LinearWarmUp(_WarmUp):
"""
Class for linear warm up
"""
def __init__(self, lr, warmup_epochs, steps_per_epoch, warmup_init_lr=0):
self.base_lr = lr
self.warmup_init_lr = warmup_init_lr
self.warmup_steps = int(warmup_epochs * steps_per_epoch)
super(_LinearWarmUp, self).__init__(warmup_init_lr)
def get_warmup_steps(self):
return self.warmup_steps
def get_lr(self, current_step=None):
lr_inc = (float(self.base_lr) - float(self.warmup_init_lr)) / float(self.warmup_steps)
lr = float(self.warmup_init_lr) + lr_inc * current_step
return lr
class _ConstWarmUp(_WarmUp):
"""
Class for const warm up
"""
def get_lr(self, current_step=None):
return self.warmup_init_lr
class _LRScheduler():
"""
Basic class for learning rate scheduler
"""
def __init__(self, lr, max_epoch, steps_per_epoch):
self.base_lr = lr
self.steps_per_epoch = steps_per_epoch
self.total_steps = int(max_epoch * steps_per_epoch)
def get_lr(self):
# Compute learning rate using chainable form of the scheduler
raise NotImplementedError
class LambdaLR(_LRScheduler):
r"""
Lambda learning rate scheduler
Sets the learning rate to the initial lr times a given function.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
lr_lambda (func. or list): A function which computes a multiplicative factor given an integer parameter epoch.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> lambda1 = lambda epoch: epoch // 30
>>> scheduler = LambdaLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0):
self.lr_lambda = lr_lambda
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(LambdaLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
cur_ep = i // self.steps_per_epoch
lr = self.base_lr * self.lr_lambda(cur_ep)
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class MultiplicativeLR(_LRScheduler):
"""
Multiplicative learning rate scheduler
Multiply the learning rate by the factor given in the specified function.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
lr_lambda (func. or list): A function which computes a multiplicative factor given an integer parameter epoch.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> lmbda = lambda epoch: 0.95
>>> scheduler = MultiplicativeLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0):
self.lr_lambda = lr_lambda
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(MultiplicativeLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
current_lr = self.base_lr
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
cur_ep = i // self.steps_per_epoch
if i % self.steps_per_epoch == 0 and cur_ep > 0:
current_lr = current_lr * self.lr_lambda(cur_ep)
lr = current_lr
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class StepLR(_LRScheduler):
"""
Step learning rate scheduler
Decays the learning rate by gamma every epoch_size epochs.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
epoch_size (int): Period of learning rate decay.
gamma (float): Multiplicative factor of learning rate decay.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 60
>>> # lr = 0.0005 if 60 <= epoch < 90
>>> # ...
>>> scheduler = StepLR(lr=0.1, epoch_size=30, gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, epoch_size, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
self.epoch_size = epoch_size
self.gamma = gamma
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(StepLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
cur_ep = i // self.steps_per_epoch
lr = self.base_lr * self.gamma ** (cur_ep // self.epoch_size)
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class POLYLR(_LRScheduler):
'''POLY LR.'''
def __init__(self, lr, steps_per_epoch, max_epoch, end_lr, power):
self.end_lr = end_lr
self.power = power
self.max_epoch = max_epoch
self.lr = lr
self.end_lr = end_lr
super(POLYLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
lr_each_step = []
total_steps = self.steps_per_epoch * self.max_epoch
for i in range(total_steps):
step_ = min(i, total_steps)
lr_each_step.append((self.lr - self.end_lr) * ((1.0 - step_ / total_steps) ** self.power) + self.end_lr)
print("lr_each_step:", lr_each_step[-1])
return np.array(lr_each_step).astype(np.float32)
class MultiStepLR(_LRScheduler):
"""
Multi-step learning rate scheduler
Decays the learning rate by gamma once the number of epoch reaches one of the milestones.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of learning rate decay.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
>>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
self.milestones = Counter(milestones)
self.gamma = gamma
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
current_lr = self.base_lr
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
cur_ep = i // self.steps_per_epoch
if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
current_lr = current_lr * self.gamma
lr = current_lr
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class ExponentialLR(_LRScheduler):
"""
Exponential learning rate scheduler
Decays the learning rate of each parameter group by gamma every epoch.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
gamma (float): Multiplicative factor of learning rate decay.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> scheduler = ExponentialLR(lr=0.1, gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):
self.gamma = gamma
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(ExponentialLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
current_lr = self.base_lr
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
if i % self.steps_per_epoch == 0 and i > 0:
current_lr = current_lr * self.gamma
lr = current_lr
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class CosineAnnealingLR(_LRScheduler):
r"""
Cosine annealing scheduler
Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}`
is set to the initial lr and :math:`T_{cur}` is the number of epochs since the
last restart in SGDR:
.. math::
\begin{aligned}
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
& T_{cur} \neq (2k+1)T_{max}; \\
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
& T_{cur} = (2k+1)T_{max}.
\end{aligned}
It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
Note:
This only implements the cosine annealing part of SGDR, and not the restarts.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
T_max (int): Maximum number of iterations.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
eta_min (float, optional): Minimum learning rate. Default: 0.
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> scheduler = CosineAnnealingLR(lr=0.1, T_max=120, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, T_max, steps_per_epoch, max_epoch, warmup_epochs=0, eta_min=0):
self.t_max = T_max
self.eta_min = eta_min
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(CosineAnnealingLR, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
current_lr = self.base_lr
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
cur_ep = i // self.steps_per_epoch
if i % self.steps_per_epoch == 0 and i > 0:
current_lr = self.eta_min + (self.base_lr - self.eta_min) * \
(1. + math.cos(math.pi * cur_ep / self.t_max)) / 2
lr = current_lr
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class CyclicLR(_LRScheduler):
r"""
Cyclical learning rate scheduler
Sets the learning rate according to cyclical learning rate policy (CLR).
The policy cycles the learning rate between two boundaries with a constant
frequency, as detailed in the paper `Cyclical Learning Rates for Training
Neural Networks`_. The distance between the two boundaries can be scaled on
a per-iteration or per-cycle basis.
Cyclical learning rate policy changes the learning rate after every batch.
This class has three built-in policies, as put forth in the paper:
* "triangular": A basic triangular cycle without amplitude scaling.
* "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
* "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
at each cycle iteration.
This implementation was adapted from the github repo: `bckenstler/CLR`_
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
max_lr (float): Upper learning rate boundaries in the cycle.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
step_size_up (int): Number of training iterations in the
increasing half of a cycle.
Default: 2000
step_size_down (int): Number of training iterations in the
decreasing half of a cycle. If step_size_down is None,
it is set to step_size_up.
Default: None
mode (str): One of {triangular, triangular2, exp_range}.
Values correspond to policies detailed above.
If scale_fn is not None, this argument is ignored.
Default: 'triangular'
gamma (float): Constant in 'exp_range' scaling function: gamma**(cycle iterations)
Default: 1.0
scale_fn (function): Custom scaling policy defined by a single argument lambda function, where
0 <= scale_fn(x) <= 1 for all x >= 0. If specified, then 'mode' is ignored.
Default: None
scale_mode (str): {'cycle', 'iterations'}. Defines whether scale_fn is evaluated on
cycle number or cycle iterations (training iterations since start of cycle).
Default: 'cycle'
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> scheduler = CyclicLR(lr=0.1, max_lr=1.0, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self,
lr,
max_lr,
steps_per_epoch,
max_epoch,
step_size_up=2000,
step_size_down=None,
mode='triangular',
gamma=1.,
scale_fn=None,
scale_mode='cycle',
warmup_epochs=0):
self.max_lr = max_lr
step_size_up = float(step_size_up)
step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
self.total_size = step_size_up + step_size_down
self.step_ratio = step_size_up / self.total_size
if mode not in ['triangular', 'triangular2', 'exp_range'] \
and scale_fn is None:
raise ValueError('mode is invalid and scale_fn is None')
self.mode = mode
self.gamma = gamma
if scale_fn is None:
if self.mode == 'triangular':
self.scale_fn = self._triangular_scale_fn
self.scale_mode = 'cycle'
elif self.mode == 'triangular2':
self.scale_fn = self._triangular2_scale_fn
self.scale_mode = 'cycle'
elif self.mode == 'exp_range':
self.scale_fn = self._exp_range_scale_fn
self.scale_mode = 'iterations'
else:
self.scale_fn = scale_fn
self.scale_mode = scale_mode
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(CyclicLR, self).__init__(lr, max_epoch, steps_per_epoch)
def _triangular_scale_fn(self, x):
if x is None:
pass
return 1.
def _triangular2_scale_fn(self, x):
return 1 / (2. ** (x - 1))
def _exp_range_scale_fn(self, x):
return self.gamma ** (x)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
# Calculates the learning rate at batch index.
cycle = math.floor(1 + i / self.total_size)
x = 1. + i / self.total_size - cycle
if x <= self.step_ratio:
scale_factor = x / self.step_ratio
else:
scale_factor = (x - 1) / (self.step_ratio - 1)
base_height = (self.max_lr - self.base_lr) * scale_factor
if self.scale_mode == 'cycle':
lr = self.base_lr + base_height * self.scale_fn(cycle)
else:
lr = self.base_lr + base_height * self.scale_fn(i)
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class CosineAnnealingWarmRestarts(_LRScheduler):
r"""
Cosine annealing scheduler with warm restarts
Set the learning rate using a cosine annealing schedule, where
:math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` is the
number of epochs since the last restart and :math:`T_{i}` is the number
of epochs between two warm restarts in SGDR:
.. math::
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
Args:
warmup_epochs (int): The number of epochs to Warmup.
Default: 0
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
T_0 (int): Number of iterations for the first restart.
T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
eta_min (float, optional): Minimum learning rate. Default: 0.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> scheduler = CosineAnnealingWarmRestarts(lr=0.1, steps_per_epoch=5000, max_epoch=90, T_0=2)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, steps_per_epoch, max_epoch, T_0, T_mult=1, eta_min=0, warmup_epochs=0):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
self.t_0 = T_0
self.t_i = T_0
self.t_mult = T_mult
self.eta_min = eta_min
self.t_cur = 0
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(CosineAnnealingWarmRestarts, self).__init__(lr, max_epoch, steps_per_epoch)
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
if i % self.steps_per_epoch == 0 and i > 0:
self.t_cur += 1
if self.t_cur >= self.t_i:
self.t_cur = self.t_cur - self.t_i
self.t_i = self.t_i * self.t_mult
lr = self.eta_min + (self.base_lr - self.eta_min) * \
(1 + math.cos(math.pi * self.t_cur / self.t_i)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
class OneCycleLR(_LRScheduler):
r"""
One cycle learning rate scheduler
Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every batch.
This scheduler is not chainable.
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
pct_start (float): The percentage of the cycle (in number of steps) spent
increasing the learning rate.
Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
linear annealing.
Default: 'cos'
div_factor (float): Determines the max learning rate via
:math:`max_lr = lr * div_factor`
Default: 25
final_div_factor (float): Determines the minimum learning rate via
:math:`min_lr = lr / final_div_factor`
Default: 1e4
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> scheduler = OneCycleLR(lr=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self,
lr,
steps_per_epoch,
max_epoch,
pct_start=0.3,
anneal_strategy='cos',
div_factor=25.,
final_div_factor=1e4,
warmup_epochs=0):
self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)
super(OneCycleLR, self).__init__(lr, max_epoch, steps_per_epoch)
self.step_size_up = float(pct_start * self.total_steps) - 1
self.step_size_down = float(self.total_steps - self.step_size_up) - 1
# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
# Validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
if anneal_strategy == 'cos':
self.anneal_func = self._annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = self._annealing_linear
# Initialize learning rate variables
self.max_lr = lr * div_factor
self.min_lr = lr / final_div_factor
def _annealing_cos(self, start, end, pct):
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
def _annealing_linear(self, start, end, pct):
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
return (end - start) * pct + start
def get_lr(self):
warmup_steps = self.warmup.get_warmup_steps()
lr_each_step = []
for i in range(self.total_steps):
if i < warmup_steps:
lr = self.warmup.get_lr(i + 1)
else:
if i <= self.step_size_up:
lr = self.anneal_func(self.base_lr, self.max_lr, i / self.step_size_up)
else:
down_step_num = i - self.step_size_up
lr = self.anneal_func(self.max_lr, self.min_lr, down_step_num / self.step_size_down)
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)

View File

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN models."""
from mindspore.train.serialization import load_checkpoint, load_param_into_net
def load_ckpt(network, pretrain_ckpt_path, trainable=True):
"""
incremental_learning or not
"""
param_dict = load_checkpoint(pretrain_ckpt_path)
load_param_into_net(network, param_dict)
if not trainable:
for param in network.get_parameters():
param.requires_grad = False

View File

@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Logger."""
import os
import sys
import logging
from datetime import datetime
logger_name_1 = 'ds-cnn'
class LOGGER(logging.Logger):
'''Build logger.'''
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
'''Setup logging file.'''
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*' * 70 + '\n') * line_width
important_msg += ('*' * line_width + '\n') * 2
important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
important_msg += ('*' * line_width + '\n') * 2
important_msg += ('*' * 70 + '\n') * line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER(logger_name_1, rank)
logger.setup_logging_file(path, rank)
return logger

View File

@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Utils."""
SILENCE_LABEL = '_silence_'
UNKNOWN_WORD_LABEL = '_unknown_'
def prepare_words_list(wanted_words):
return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words

View File

@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DSCNN train."""
import os
import datetime
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor, Model
from mindspore.nn.optim import Momentum
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint
from src.config import train_config
from src.log import get_logger
from src.dataset import audio_dataset
from src.ds_cnn import DSCNN
from src.loss import CrossEntropy
from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
from src.callback import ProgressMonitor, callback_func
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def val(args, model, val_dataset):
'''Eval.'''
val_dataloader = val_dataset.create_tuple_iterator()
img_tot = 0
top1_correct = 0
top5_correct = 0
for data, gt_classes in val_dataloader:
output = model.predict(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
gt_classes = gt_classes.asnumpy()
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += output.shape[0]
results = [[top1_correct], [top5_correct], [img_tot]]
results = np.array(results)
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
if acc1 > args.best_acc:
args.best_acc = acc1
args.best_epoch = args.epoch_cnt - 1
args.logger.info('Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%' \
.format(top1_correct, top5_correct, img_tot, acc1, acc5))
def trainval(args, model, train_dataset, val_dataset, cb):
callbacks = callback_func(args, cb, 'epoch{}'.format(args.epoch_cnt))
model.train(args.val_interval, train_dataset, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode)
val(args, model, val_dataset)
def train():
'''Train.'''
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
args, model_settings = train_config(parser)
context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id, enable_auto_mixed_precision=True)
args.rank_save_ckpt_flag = 1
# Logger
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir)
# Dataloader: train, val
train_dataset = audio_dataset(args.feat_dir, 'training', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
args.steps_per_epoch = train_dataset.get_dataset_size()
val_dataset = audio_dataset(args.feat_dir, 'validation', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
# show args
args.logger.save_args(args)
# Network
args.logger.important_info('start create network')
network = DSCNN(model_settings, args.model_size_info)
# Load pretrain model
if os.path.isfile(args.pretrained):
load_checkpoint(args.pretrained, network)
args.logger.info('load model {} success'.format(args.pretrained))
# Loss
criterion = CrossEntropy(num_classes=model_settings['label_count'])
# LR scheduler
if args.lr_scheduler == 'multistep':
lr_scheduler = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch,
args.max_epoch, warmup_epochs=args.warmup_epochs)
elif args.lr_scheduler == 'cosine_annealing':
lr_scheduler = CosineAnnealingLR(args.lr, args.T_max, args.steps_per_epoch, args.max_epoch,
warmup_epochs=args.warmup_epochs, eta_min=args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
lr_schedule = lr_scheduler.get_lr()
# Optimizer
opt = Momentum(params=network.trainable_params(),
learning_rate=Tensor(lr_schedule),
momentum=args.momentum,
weight_decay=args.weight_decay)
model = Model(network, loss_fn=criterion, optimizer=opt, amp_level='O0')
# Training
args.epoch_cnt = 0
args.best_epoch = 0
args.best_acc = 0
progress_cb = ProgressMonitor(args)
while args.epoch_cnt + args.val_interval < args.max_epoch:
trainval(args, model, train_dataset, val_dataset, progress_cb)
rest_ep = args.max_epoch - args.epoch_cnt
if rest_ep > 0:
trainval(args, model, train_dataset, val_dataset, progress_cb)
args.logger.info('Best epoch:{} acc:{:.2f}%'.format(args.best_epoch, args.best_acc))
if __name__ == "__main__":
train()