!15983 DS-CNN implementation on GPU

From: @charlie__chen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
mindspore-ci-bot 2021-05-10 17:18:17 +08:00 committed by Gitee
commit f10c015be4
8 changed files with 158 additions and 47 deletions

View File: README.md

@@ -94,6 +94,20 @@ First set the config for data, train, eval in src/config.py
python eval.py --model_dir your_models_folder_path
```
- running on GPU
```bash
# run training example
python train.py --amp_level 'O3' --device_target='GPU'
# run evaluation example
# if you want to evaluate a specific model, set model_dir to the ckpt path:
python eval.py --device_id 0 --model_dir your_ckpt_path --device_target 'GPU'
# if you want to evaluate all the models you saved, set model_dir to the folder where they are saved.
python eval.py --device_id 0 --model_dir your_models_folder_path --device_target 'GPU'
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
@@ -105,6 +119,8 @@ First set the config for data, train, eval in src/config.py
│ ├──run_download_process_data.sh // shell script for download dataset and prepare feature and label
│ ├──run_train_ascend.sh // shell script for train on ascend
│ ├──run_eval_ascend.sh // shell script for evaluation on ascend
│ ├──run_train_gpu.sh // shell script for train on gpu
│ ├──run_eval_gpu.sh // shell script for evaluation on gpu
├── src
│ ├──callback.py // callbacks
│ ├──config.py // parameter configuration of data, train and eval
@@ -173,6 +189,8 @@ Parameters for both training and evaluation can be set in config.py.
'log_interval': 100 # logging interval
'ckpt_path': 'train_outputs' # the location where checkpoint and log will be saved
'ckpt_interval': 100 # save a checkpoint every ckpt_interval steps
'device_target': 'Ascend' # device target used to train or evaluate the dataset.
'amp_level': 'O3' # amp level for mixed-precision training
```
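The last two entries are the ones this commit introduces. Below is a minimal sketch, not the repo's code, of where they end up: `device_target` feeds `context.set_context` and `amp_level` feeds `Model`, as in the train.py diff further down; the `nn.Dense` network and hyper-parameter values are illustrative stand-ins.

```python
import mindspore.nn as nn
from mindspore import context, Model

# device_target selects the backend, as train.py now does from the config.
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', device_id=0)

net = nn.Dense(10, 12)  # stand-in for the DS-CNN network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.98)

# amp_level='O3' runs the network in float16; keep_batchnorm_fp32=False
# mirrors the Model(...) call this commit adds to train.py.
model = Model(net, loss_fn=loss, optimizer=opt,
              amp_level='O3', keep_batchnorm_fp32=False)
```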
- config for DS-CNN and evaluation parameters of Speech commands dataset version 1
@@ -184,6 +202,7 @@ Parameters for both training and evaluation can be set in config.py.
# Words to use (others will be added to an unknown label)
'sample_rate': 16000 # Expected sample rate of the wavs
'device_id': 1000 # device ID used to train or evaluate the dataset.
'device_target': 'Ascend' # device target used to train or evaluate the dataset.
'clip_duration_ms': 10 # Expected duration in milliseconds of the wavs
'window_size_ms': 40.0 # How long each spectrogram timeslice is
'window_stride_ms': 20.0 # How far to move in time between spectrogram timeslices
@@ -227,6 +246,15 @@ Parameters for both training and evaluation can be set in config.py.
Best epoch:41 acc:93.73%
```
- running on GPU
for shell script:
```bash
# sh scripts/run_train_gpu.sh [device_num] [cuda_visible_devices] [amp_level]
sh scripts/run_train_gpu.sh 1 0 'O3'
```
The checkpoints and logs will be saved in the train_outputs folder.
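For reference, a sketch of how those checkpoints are produced, assuming it mirrors `callback_func` in `src/callback.py` (diffed below); the interval, directory, and prefix values are hard-coded here for illustration:

```python
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor

# One checkpoint every 'ckpt_interval' steps, written under 'ckpt_path'.
ckpt_config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=30)
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory='train_outputs',
                          prefix='epoch0')
# TimeMonitor is the per-step timing callback this commit adds.
callbacks = [ckpt_cb, TimeMonitor(100)]
# then: model.train(epochs, train_dataset, callbacks=callbacks)
```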
## [Evaluation Process](#contents)
@@ -255,6 +283,17 @@ Parameters for both training and evaluation can be set in config.py.
python eval.py --device_id 0 --model_dir train_outputs/*
```
- evaluation on Speech commands dataset version 1 when running on GPU
for shell script:
```bash
# sh scripts/run_eval_gpu.sh [device_id] [model_dir]
sh scripts/run_eval_gpu.sh 0 train_outputs/*/*.ckpt
# or
sh scripts/run_eval_gpu.sh 0 train_outputs/*/
```
You can view the results on the screen or from the logs in the eval_outputs folder. The accuracy on the test dataset will be as follows:
```python
@@ -268,39 +307,39 @@ Parameters for both training and evaluation can be set in config.py.
### Train Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | DS-CNN |
| Resource | Ascend 910; CPU 2.60GHz, 56cores; Memory 314G; OS Euler2.8 |
| uploaded Date | 09/27/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | Speech commands dataset version 1 |
| Training Parameters | epoch=80, batch_size = 100, lr=0.1 |
| Optimizer | Momentum |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.0019 |
| Speed | 2s/epoch |
| Total time | 4 mins |
| Parameters (K) | 500K |
| Checkpoint for Fine tuning | 3.3M (.ckpt file) |
| Parameters | Ascend | GPU |
| -------------------------- | ------------------------------------------------------------ | -------------------------------------------------|
| Model Version | DS-CNN | DS-CNN |
| Resource | Ascend 910; CPU 2.60GHz, 56cores; Memory 314G; OS Euler2.8 | NV SXM2 V100-32G |
| uploaded Date | 09/27/2020 (month/day/year) | 05/05/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.2.0 |
| Dataset | Speech commands dataset version 1 | Speech commands dataset version 1 |
| Training Parameters | epoch=80, batch_size = 100, lr=0.1 | epoch=80, batch_size = 100, lr=0.1 |
| Optimizer | Momentum | Momentum |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Loss | 0.0019 | 0.003304138 |
| Speed | 2s/epoch | 3s/epoch |
| Total time | 4 mins | 6 mins |
| Parameters (K) | 500K | 500K |
| Checkpoint for Fine tuning | 3.3M (.ckpt file) | 3.3M (.ckpt file) |
| Script | [Link]() | [Link]() |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | DS-CNN |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 09/27/2020 |
| MindSpore Version | 1.0.0 |
| Dataset | Speech commands dataset version 1 |
| Training Parameters | src/config.py |
| outputs | probability |
| Accuracy | 93.96% |
| Total time | 3min |
| Params (K) | 500K |
| Checkpoint for Fine tuning (M) | 3.3M |
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | -------------------------|
| Model Version | DS-CNN | DS-CNN |
| Resource | Ascend 910; OS Euler2.8 | NV SXM2 V100-32G |
| Uploaded Date | 09/27/2020 | 05/05/2021 |
| MindSpore Version | 1.0.0 | 1.2.0 |
| Dataset | Speech commands dataset version 1 | Speech commands dataset version 1 |
| Training Parameters | src/config.py | src/config.py |
| outputs | probability | probability |
| Accuracy | 93.96% | 93.97% |
| Total time | 3min | 2min20s |
| Params (K) | 500K | 500K |
| Checkpoint for Fine tuning (M) | 3.3M | 3.3M |
# [Description of Random Situation](#contents)

View File: eval.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -74,8 +74,9 @@ def val(args, model, test_de):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'])
args, model_settings = eval_config(parser)
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
# Logger
args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

View File: scripts/run_eval_gpu.sh

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
python eval.py --device_id $1 --model_dir $2 --device_target 'GPU' > eval.log 2>&1 &

View File: scripts/run_train_gpu.sh

@@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
export DEVICE_NUM=$1
export RANK_SIZE=$1
export CUDA_VISIBLE_DEVICES="$2"
if [ $1 -gt 1 ]
then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --is_distributed --amp_level $3 --device_target="GPU" > train.log 2>&1 &
else
python train.py --amp_level $3 --device_target='GPU' > train.log 2>&1 &
fi
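The three positional arguments are the device count (exported as both `DEVICE_NUM` and `RANK_SIZE`), the list assigned to `CUDA_VISIBLE_DEVICES`, and the amp level; a count above 1 takes the `mpirun` branch, so a hypothetical 8-GPU run would be `sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7 'O3'`.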

View File: src/callback.py

@@ -16,6 +16,7 @@
import time
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import TimeMonitor
from mindspore.train.callback import CheckpointConfig, Callback
@@ -85,4 +86,5 @@ def callback_func(args, cb, prefix):
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.outputs_dir, prefix=prefix)
callbacks.append(ckpt_cb)
callbacks.append(TimeMonitor(args.per_batch_size))
return callbacks

View File: src/dataset.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,12 +35,13 @@ class NpyDataset():
return data.astype(np.float32), label.astype(np.int32)
def audio_dataset(data_dir, data_type, h, w, batch_size):
def audio_dataset(data_dir, data_type, h, w, batch_size, device_num=1, rank=0):
if 'testing' in data_dir:
shuffle = False
else:
shuffle = True
dataset = NpyDataset(data_dir, data_type, h, w)
de_dataset = de.GeneratorDataset(dataset, ["feats", "labels"], shuffle=shuffle)
de_dataset = de_dataset.batch(batch_size, drop_remainder=False)
de_dataset = de.GeneratorDataset(dataset, ["feats", "labels"], shuffle=shuffle,
num_shards=device_num, shard_id=rank)
de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
return de_dataset
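The new `num_shards`/`shard_id` arguments give each rank a disjoint 1/`device_num` slice of the data, and `drop_remainder=True` keeps the step count identical on every rank. A self-contained sketch of the same pattern, with a random `NpySource` standing in for the real feature files:

```python
import numpy as np
import mindspore.dataset as de

class NpySource:
    """Illustrative stand-in for NpyDataset: random MFCC-like features."""
    def __init__(self, n=1000, h=49, w=10):
        self.feats = np.random.rand(n, 1, h, w).astype(np.float32)
        self.labels = np.random.randint(0, 12, n).astype(np.int32)
    def __getitem__(self, index):
        return self.feats[index], self.labels[index]
    def __len__(self):
        return len(self.feats)

device_num, rank = 2, 0  # from get_group_size() / get_rank() when distributed
ds = de.GeneratorDataset(NpySource(), ["feats", "labels"], shuffle=True,
                         num_shards=device_num, shard_id=rank)
ds = ds.batch(100, drop_remainder=True)  # equal batch counts on every rank
print(ds.get_dataset_size())  # 1000 samples / 2 shards / 100 per batch = 5
```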

View File: src/download_process_data.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,8 +30,8 @@ import soundfile as sf
from python_speech_features import mfcc
from tqdm import tqdm
from src.config import train_config, prepare_model_settings
from src.utils import prepare_words_list
from config import train_config, prepare_model_settings
from utils import prepare_words_list
FLAGS = None
MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1 # ~134M

View File: train.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,9 +20,11 @@ import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor, Model
from mindspore.train.model import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint
from mindspore.communication.management import init, get_rank, get_group_size
from src.config import train_config
from src.log import get_logger
@@ -47,6 +49,12 @@ def val(args, model, val_dataset):
img_tot = 0
top1_correct = 0
top5_correct = 0
if args.amp_level == 'O0':
origin_mstype = mstype.float32
else:
origin_mstype = mstype.float16
model.predict_network.to_float(mstype.float32)
for data, gt_classes in val_dataloader:
output = model.predict(Tensor(data, mstype.float32))
output = output.asnumpy()
@@ -58,6 +66,7 @@ def val(args, model, val_dataset):
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += output.shape[0]
model.predict_network.to_float(origin_mstype)
results = [[top1_correct], [top5_correct], [img_tot]]
results = np.array(results)
@@ -74,27 +83,44 @@ def val(args, model, val_dataset):
.format(top1_correct, top5_correct, img_tot, acc1, acc5))
def trainval(args, model, train_dataset, val_dataset, cb):
def trainval(args, model, train_dataset, val_dataset, cb, rank):
callbacks = callback_func(args, cb, 'epoch{}'.format(args.epoch_cnt))
model.train(args.val_interval, train_dataset, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode)
val(args, model, val_dataset)
if rank == 0:
val(args, model, val_dataset)
def train():
'''Train.'''
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
parser.add_argument('--device_id', type=int, default=0, help='which device the model will be trained on')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--amp_level', type=str, default='O0', choices=['O3', 'O2', 'O0'])
args, model_settings = train_config(parser)
context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id, enable_auto_mixed_precision=True)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_auto_mixed_precision=True)
args.rank_save_ckpt_flag = 1
# init distributed
if args.is_distributed:
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
rank = get_rank()
group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
else:
rank = 0
group_size = 1
context.set_context(device_id=args.device_id)
# Logger
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir)
# Dataloader: train, val
train_dataset = audio_dataset(args.feat_dir, 'training', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
model_settings['dct_coefficient_count'], args.per_batch_size, group_size, rank)
args.steps_per_epoch = train_dataset.get_dataset_size()
val_dataset = audio_dataset(args.feat_dir, 'validation', model_settings['spectrogram_length'],
model_settings['dct_coefficient_count'], args.per_batch_size)
@@ -131,7 +157,7 @@ def train():
momentum=args.momentum,
weight_decay=args.weight_decay)
model = Model(network, loss_fn=criterion, optimizer=opt, amp_level='O0')
model = Model(network, loss_fn=criterion, optimizer=opt, amp_level=args.amp_level, keep_batchnorm_fp32=False)
# Training
args.epoch_cnt = 0
@@ -139,10 +165,10 @@ def train():
args.best_acc = 0
progress_cb = ProgressMonitor(args)
while args.epoch_cnt + args.val_interval < args.max_epoch:
trainval(args, model, train_dataset, val_dataset, progress_cb)
trainval(args, model, train_dataset, val_dataset, progress_cb, rank)
rest_ep = args.max_epoch - args.epoch_cnt
if rest_ep > 0:
trainval(args, model, train_dataset, val_dataset, progress_cb)
trainval(args, model, train_dataset, val_dataset, progress_cb, rank)
args.logger.info('Best epoch:{} acc:{:.2f}%'.format(args.best_epoch, args.best_acc))
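A side note on the `val()` change above: when training runs in float16 (any `amp_level` other than 'O0'), evaluation temporarily casts the predict network to float32 and restores the original dtype afterwards. A minimal sketch of that trick, with a `Dense` stand-in for the real network and illustrative hyper-parameters:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Model
from mindspore.common import dtype as mstype

net = nn.Dense(10, 12)  # stand-in for the DS-CNN network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.98)
model = Model(net, loss_fn=loss, optimizer=opt,
              amp_level='O3', keep_batchnorm_fp32=False)

origin_mstype = mstype.float16  # would be float32 if amp_level were 'O0'
model.predict_network.to_float(mstype.float32)   # evaluate in full precision
out = model.predict(Tensor(np.random.rand(4, 10), mstype.float32)).asnumpy()
model.predict_network.to_float(origin_mstype)    # restore the training dtype
print(out.shape)  # (4, 12) class scores
```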