# Source page: forked from mindspore-Ecosystem/mindspore (Python, 113 lines, 3.9 KiB)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
|
|
"""DSCNN eval."""
|
|
import os
|
|
import datetime
|
|
import glob
|
|
import argparse
|
|
|
|
import numpy as np
|
|
from mindspore import context
|
|
from mindspore import Tensor, Model
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from src.config import eval_config
|
|
from src.log import get_logger
|
|
from src.dataset import audio_dataset
|
|
from src.ds_cnn import DSCNN
|
|
from src.models import load_ckpt
|
|
|
|
def get_top5_acc(top5_arg, gt_class):
    """Count samples whose ground-truth label is among their candidate labels.

    Args:
        top5_arg: iterable of per-sample candidate label collections.
        gt_class: iterable of ground-truth labels, paired positionally with
            ``top5_arg`` (pairing stops at the shorter of the two).

    Returns:
        int: number of pairs where the label is contained in its candidates.
    """
    return sum(
        1
        for candidates, label in zip(top5_arg, gt_class)
        if label in candidates
    )
|
|
|
|
|
|
def val(args, model, test_de):
    """Evaluate one model on the test dataset and log top-1 / top-5 accuracy.

    Every batch from ``test_de`` is passed through ``model.predict``; top-1
    and top-5 hits are accumulated over all batches. If the resulting top-1
    accuracy is higher than ``args.best_acc``, ``args.best_acc`` and
    ``args.best_index`` are updated (the latter from ``args.index``).

    Args:
        args: namespace providing ``best_acc``, ``best_index``, ``index`` and
            ``logger``.
        model: object whose ``predict`` returns something with ``asnumpy()``.
        test_de: dataset exposing ``create_tuple_iterator()`` that yields
            ``(data, gt_classes)`` pairs.
    """
    sample_count = 0
    top1_hits = 0
    top5_hits = 0
    for data, gt_classes in test_de.create_tuple_iterator():
        scores = model.predict(Tensor(data, mstype.float32)).asnumpy()
        labels = gt_classes.asnumpy()

        top1_pred = np.argmax(scores, (-1))
        # argsort is ascending, so the last five columns are the five
        # highest-scoring classes of each row.
        top5_pred = np.argsort(scores)[:, -5:]

        top1_hits += np.equal(top1_pred, labels).sum()
        top5_hits += get_top5_acc(top5_pred, labels)
        sample_count += scores.shape[0]

    # The counters are read back from a NumPy array, so the divisions below
    # are carried out on NumPy scalars.
    totals = np.array([[top1_hits], [top5_hits], [sample_count]])
    top1_correct, top5_correct, img_tot = totals[:, 0]

    acc1 = 100.0 * top1_correct / img_tot
    acc5 = 100.0 * top5_correct / img_tot

    # Remember which checkpoint (by position) produced the best top-1 score.
    if acc1 > args.best_acc:
        args.best_index = args.index
        args.best_acc = acc1

    args.logger.info(
        'Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%'.format(
            top1_correct, top5_correct, img_tot, acc1, acc5))
|
|
|
|
|
|
|
|
def main():
    """Evaluate one checkpoint, or every checkpoint in a directory, and report the best.

    Steps:
      1. Parse the command line via ``eval_config`` and configure the
         MindSpore context.
      2. Create a time-stamped log directory under ``args.log_path`` and a
         logger writing into it.
      3. Resolve ``args.model_dir`` to a list of checkpoint paths: every
         ``*.ckpt`` file when it is a directory, otherwise the path itself.
      4. Evaluate each checkpoint with ``val`` and log the one with the
         highest top-1 accuracy.

    Raises:
        FileNotFoundError: if ``args.model_dir`` is a directory that holds no
            ``*.ckpt`` files. Without this check the run would only fail at
            the final log line with an ``IndexError`` on the empty list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
    args, model_settings = eval_config(parser)
    # NOTE(review): "Davinci" is the device target name this script passes;
    # confirm it is accepted by the installed MindSpore version.
    context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", device_id=args.device_id)

    # Logger
    args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir)
    # show args
    args.logger.save_args(args)

    # find model path
    if os.path.isdir(args.model_dir):
        models = glob.glob(os.path.join(args.model_dir, '*.ckpt'))
        print(models)
        if not models:
            # Fail early and clearly instead of crashing on the empty list
            # when the best model is reported.
            raise FileNotFoundError('no *.ckpt files found in {}'.format(args.model_dir))
        args.models = sorted(models, key=_checkpoint_sort_key)
    else:
        args.models = [args.model_dir]

    args.best_acc = 0
    args.index = 0
    args.best_index = 0
    for model_path in args.models:
        test_de = audio_dataset(args.feat_dir, 'testing', model_settings['spectrogram_length'],
                                model_settings['dct_coefficient_count'], args.per_batch_size)
        network = DSCNN(model_settings, args.model_size_info)

        load_ckpt(network, model_path, False)
        network.set_train(False)
        model = Model(network)
        args.logger.info('load model {} success'.format(model_path))
        val(args, model, test_de)
        # ``val`` reads ``args.index`` to record the position of the best
        # checkpoint, so it must track the position in ``args.models``.
        args.index += 1

    args.logger.info('Best model:{} acc:{:.2f}%'.format(args.models[args.best_index], args.best_acc))


def _checkpoint_sort_key(ckpt_path):
    """Return the negated epoch number parsed from a checkpoint file name.

    The number is taken from the file stem: the part before the first ``-``,
    then whatever follows the last ``epoch`` in that part. Negating it makes
    ``sorted`` place higher epoch numbers first.

    Raises:
        ValueError: if the extracted text is not an integer.
    """
    file_stem = os.path.splitext(os.path.split(ckpt_path)[-1])[0]
    return -1 * int(file_stem.split('-')[0].split('epoch')[-1])
|
|
|
|
# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()
|