138 lines
4.8 KiB
Python
138 lines
4.8 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
'''
|
|
##############evaluate trained models#################
|
|
python eval.py
|
|
'''
|
|
|
|
import numpy as np
|
|
|
|
from src.model_utils.config import config
|
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
|
from src.model_utils.device_adapter import get_device_id
|
|
from src.musictagger import MusicTaggerCNN
|
|
from src.dataset import create_dataset
|
|
|
|
import mindspore.common.dtype as mstype
|
|
from mindspore import context
|
|
from mindspore import Tensor
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
|
def calculate_auc(labels_list, preds_list):
    """
    Estimate the mean per-class ROC AUC with a histogram approximation.

    Scores are bucketed into ``max(n_samples // 2, 1)`` equal-width bins over
    [0, 1]. For every class (column) the number of correctly ordered
    (positive, negative) pairs is counted from the positive and negative
    histograms; a pair whose two scores land in the same bin counts as half
    a pair.

    Input:
        labels_list: array of true labels, shape (n_samples,) or
            (n_samples, n_classes); any non-zero entry is a positive
        preds_list: array of predicted scores with the same shape; scores
            outside [0, 1] are clamped to the nearest edge bin

    Outputs
        Float, mean of the per-class AUC values. A class that has no positive
        or no negative sample has an undefined AUC and contributes NaN, which
        makes the returned mean NaN as well.

    Raises:
        ValueError: if preds_list contains NaN or infinite values
    """
    labels_list = np.asarray(labels_list)
    preds_list = np.asarray(preds_list, dtype=np.float64)
    if labels_list.ndim == 1:
        labels_list = labels_list.reshape(-1, 1)
        preds_list = preds_list.reshape(-1, 1)
    if not np.isfinite(preds_list).all():
        raise ValueError("preds_list must not contain NaN or infinite values")

    # Keep at least one bin so inputs with fewer than two samples do not
    # divide by zero when the bin width is computed.
    n_bins = max(labels_list.shape[0] // 2, 1)
    bin_width = 1.0 / n_bins

    auc = []
    for col in range(labels_list.shape[1]):
        is_positive = labels_list[:, col].astype(bool)
        scores = np.clip(preds_list[:, col], 0.0, 1.0)
        # A score of exactly 1.0 floor-divides to n_bins, one past the last
        # valid index, so fold it into the top bin.
        bin_index = np.minimum((scores // bin_width).astype(np.int64),
                               n_bins - 1)

        positive_histogram = np.bincount(bin_index[is_positive],
                                         minlength=n_bins)
        negative_histogram = np.bincount(bin_index[~is_positive],
                                         minlength=n_bins)

        # Count positives/negatives from the histograms so the denominator
        # always agrees with the pairs counted in the numerator.
        positive_len = int(positive_histogram.sum())
        negative_len = int(negative_histogram.sum())
        total_case = positive_len * negative_len
        if total_case == 0:
            # Only one class present: AUC is undefined for this column.
            auc.append(float("nan"))
            continue

        # Number of negatives in strictly lower bins, for every bin.
        accumulated_negative = np.cumsum(negative_histogram) - negative_histogram
        satisfied_pair = np.sum(
            positive_histogram * (accumulated_negative + 0.5 * negative_histogram))
        auc.append(satisfied_pair / total_case)

    return np.mean(auc)
|
|
|
|
|
|
def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimate the performance of trained model

    Runs `net` over every batch yielded by the validation dataset, gathers
    the predictions and labels as numpy arrays and scores them with
    `calculate_auc`.

    Input:
        net: the trained neural network
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: split number of validation dataset
        batch: validation batch size
    Outputs
        Float, AUC
    """
    # The batch size used to be hard-coded to 32 here, which silently ignored
    # the `batch` argument supplied by callers.
    # NOTE(review): assumes the third positional argument of create_dataset
    # is the batch size -- confirm against src/dataset.py.
    dataset = create_dataset(data_dir, filename, batch, ['feature', 'label'],
                             num_consumer)

    res_pred = []
    res_true = []
    for data, label in dataset.create_tuple_iterator():
        output = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(output.asnumpy())
        res_true.append(label.asnumpy())

    # Stitch the per-batch results back together along the sample axis.
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    return calculate_auc(res_true, res_pred)
|
|
|
|
|
|
def validation(net, model_path, data_dir, filename, num_consumer, batch):
    """
    Restore a checkpoint into the network and score it on validation data

    Input:
        net: network that receives the checkpoint parameters
        model_path: path of the checkpoint file handed to load_checkpoint
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: forwarded unchanged to val
        batch: forwarded unchanged to val
    Outputs
        the AUC value returned by val
    """
    checkpoint_params = load_checkpoint(model_path)
    load_param_into_net(net, checkpoint_params)
    return val(net, data_dir, filename, num_consumer, batch)
|
|
|
|
|
|
def modelarts_process():
    """Pre-process hook handed to moxing_wrapper; intentionally does nothing."""
    return None
|
|
|
|
@moxing_wrapper(pre_process=modelarts_process)
def fcn4_eval():
    """
    Evaluate a MusicTaggerCNN checkpoint and print its validation AUC

    Configures the MindSpore context in graph mode for the device named in
    `config`, builds the network in inference mode, loads the checkpoint
    located at `config.checkpoint_path` / `config.model_name` and reports the
    AUC measured on `config.val_filename`.
    """
    context.set_context(
        device_target=config.device_target,
        mode=context.GRAPH_MODE,
        device_id=get_device_id(),
    )

    # Fixed architecture hyper-parameters for the evaluated network.
    channel_sizes = [1, 128, 384, 768, 2048]
    kernel_sizes = [3, 3, 3, 3, 3]
    paddings = [0] * 5
    pool_sizes = [(2, 4), (4, 5), (3, 8), (4, 8)]
    network = MusicTaggerCNN(in_classes=channel_sizes,
                             kernel_size=kernel_sizes,
                             padding=paddings,
                             maxpool=pool_sizes,
                             has_bias=True)
    network.set_train(False)

    checkpoint_file = config.checkpoint_path + "/" + config.model_name
    auc_val = validation(network,
                         checkpoint_file,
                         config.data_dir,
                         config.val_filename,
                         config.num_consumer,
                         config.batch_size)

    print("=" * 10 + "Validation Performance" + "=" * 10)
    print("AUC: {:.5f}".format(auc_val))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported as a module.
    fcn4_eval()
|