mass eval mertric update.
This commit is contained in:
parent
1f4944fa15
commit
5ec9168306
|
@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
|
|||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
|
||||
def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
|
||||
"""FasterRcnn evaluation."""
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
def bias_init_zeros(shape):
|
||||
"""Bias init method."""
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore import Tensor
|
|||
from mindspore import context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Proposal(nn.Cell):
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore import context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
def weight_init_ones(shape):
|
||||
|
|
|
@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
|
|||
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if not args_opt.do_eval and args_opt.run_distribute:
|
||||
|
|
|
@ -15,15 +15,13 @@
|
|||
"""Evaluation api."""
|
||||
import argparse
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from config import TransformerConfig
|
||||
from src.transformer import infer
|
||||
from src.utils import ngram_ppl
|
||||
from src.transformer import infer, infer_ppl
|
||||
from src.utils import Dictionary
|
||||
from src.utils import rouge
|
||||
from src.utils import get_score
|
||||
|
||||
parser = argparse.ArgumentParser(description='Evaluation MASS.')
|
||||
parser.add_argument("--config", type=str, required=True,
|
||||
|
@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True,
|
|||
help="Vocabulary to use.")
|
||||
parser.add_argument("--output", type=str, required=True,
|
||||
help="Result file path.")
|
||||
parser.add_argument("--metric", type=str, default='rouge',
|
||||
help='Set eval method.')
|
||||
|
||||
|
||||
def get_config(config):
|
||||
|
@ -45,31 +45,15 @@ if __name__ == '__main__':
|
|||
args, _ = parser.parse_known_args()
|
||||
vocab = Dictionary.load_from_persisted_dict(args.vocab)
|
||||
_config = get_config(args.config)
|
||||
result = infer(_config)
|
||||
|
||||
if args.metric == 'rouge':
|
||||
result = infer(_config)
|
||||
else:
|
||||
result = infer_ppl(_config)
|
||||
|
||||
with open(args.output, "wb") as f:
|
||||
pickle.dump(result, f, 1)
|
||||
|
||||
ppl_score = 0.
|
||||
preds = []
|
||||
tgts = []
|
||||
_count = 0
|
||||
for sample in result:
|
||||
sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
|
||||
sentence_prob = sentence_prob[:, 1:]
|
||||
_ppl = []
|
||||
for path in sentence_prob:
|
||||
_ppl.append(ngram_ppl(path, log_softmax=True))
|
||||
ppl = np.min(_ppl)
|
||||
preds.append(' '.join([vocab[t] for t in sample['prediction']]))
|
||||
tgts.append(' '.join([vocab[t] for t in sample['target']]))
|
||||
print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
|
||||
print(f" | target: {tgts[-1]}")
|
||||
print(f" | prediction: {preds[-1]}")
|
||||
print(f" | ppl: {ppl}.")
|
||||
if np.isinf(ppl):
|
||||
continue
|
||||
ppl_score += ppl
|
||||
_count += 1
|
||||
|
||||
print(f" | PPL={ppl_score / _count}.")
|
||||
rouge(preds, tgts)
|
||||
# get score by given metric
|
||||
score = get_score(result, vocab, metric=args.metric)
|
||||
print(score)
|
||||
|
|
|
@ -18,7 +18,7 @@ export DEVICE_ID=0
|
|||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
|
||||
options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
|
||||
eval set -- "$options"
|
||||
echo $options
|
||||
|
||||
|
@ -35,6 +35,7 @@ echo_help()
|
|||
echo " -c --config set the configuration file"
|
||||
echo " -o --output set the output file of inference"
|
||||
echo " -v --vocab set the vocabulary"
|
||||
echo " -m --metric set the metric"
|
||||
}
|
||||
|
||||
set_hccl_json()
|
||||
|
@ -43,8 +44,8 @@ set_hccl_json()
|
|||
do
|
||||
if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
|
||||
then
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
|
||||
export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$2
|
||||
export RANK_TABLE_FILE=$2
|
||||
break
|
||||
fi
|
||||
shift
|
||||
|
@ -119,6 +120,11 @@ do
|
|||
vocab=$2
|
||||
shift 2
|
||||
;;
|
||||
-m|--metric)
|
||||
echo "metric";
|
||||
metric=$2
|
||||
shift 2
|
||||
;;
|
||||
--)
|
||||
shift
|
||||
break
|
||||
|
@ -163,7 +169,7 @@ do
|
|||
python train.py --config ${configurations##*/} >>log.log 2>&1 &
|
||||
elif [ "$task" == "infer" ]
|
||||
then
|
||||
python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 &
|
||||
python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 &
|
||||
fi
|
||||
cd ../
|
||||
done
|
||||
|
|
|
@ -19,10 +19,11 @@ from .decoder import TransformerDecoder
|
|||
from .beam_search import BeamSearchDecoder
|
||||
from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \
|
||||
TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
|
||||
from .infer_mass import infer
|
||||
from .infer_mass import infer, infer_ppl
|
||||
|
||||
__all__ = [
|
||||
"infer",
|
||||
"infer_ppl",
|
||||
"TransformerTraining",
|
||||
"LabelSmoothedCrossEntropyCriterion",
|
||||
"TransformerTrainOneStepWithLossScaleCell",
|
||||
|
|
|
@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
|
||||
init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
|
||||
init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32)
|
||||
# 0 is Padding index, thus init it as 0.
|
||||
init_weight[0, :] = 0
|
||||
self.embedding_table = Parameter(Tensor(init_weight),
|
||||
|
|
|
@ -17,13 +17,16 @@ import time
|
|||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from mindspore import context
|
||||
|
||||
from src.dataset import load_dataset
|
||||
from .transformer_for_infer import TransformerInferModel
|
||||
from .transformer_for_train import TransformerTraining
|
||||
from ..utils.load_weights import load_infer_weights
|
||||
|
||||
context.set_context(
|
||||
|
@ -156,3 +159,129 @@ def infer(config):
|
|||
shuffle=False) if config.test_dataset else None
|
||||
prediction = transformer_infer(config, eval_dataset)
|
||||
return prediction
|
||||
|
||||
|
||||
class TransformerInferPPLCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of transformer network infer for PPL.
|
||||
|
||||
Args:
|
||||
config(TransformerConfig): Config.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor], predicted log prob and label lengths.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(TransformerInferPPLCell, self).__init__()
|
||||
self.transformer = TransformerTraining(config, is_training=False, use_one_hot_embeddings=False)
|
||||
self.batch_size = config.batch_size
|
||||
self.vocab_size = config.vocab_size
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(float(1), mstype.float32)
|
||||
self.off_value = Tensor(float(0), mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
self.flat_shape = (config.batch_size * config.seq_length,)
|
||||
self.batch_shape = (config.batch_size, config.seq_length)
|
||||
self.last_idx = (-1,)
|
||||
|
||||
def construct(self,
|
||||
source_ids,
|
||||
source_mask,
|
||||
target_ids,
|
||||
target_mask,
|
||||
label_ids,
|
||||
label_mask):
|
||||
"""Defines the computation performed."""
|
||||
|
||||
predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask)
|
||||
label_ids = self.reshape(label_ids, self.flat_shape)
|
||||
label_mask = self.cast(label_mask, mstype.float32)
|
||||
one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
|
||||
label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx)
|
||||
label_log_probs = self.reshape(label_log_probs, self.batch_shape)
|
||||
log_probs = label_log_probs * label_mask
|
||||
lengths = self.reduce_sum(label_mask, self.last_idx)
|
||||
|
||||
return log_probs, lengths
|
||||
|
||||
|
||||
def transformer_infer_ppl(config, dataset):
|
||||
"""
|
||||
Run infer with Transformer for PPL.
|
||||
|
||||
Args:
|
||||
config (TransformerConfig): Config.
|
||||
dataset (Dataset): Dataset.
|
||||
|
||||
Returns:
|
||||
List[Dict], prediction, each example has 4 keys, "source",
|
||||
"target", "log_prob" and "length".
|
||||
"""
|
||||
tfm_infer = TransformerInferPPLCell(config=config)
|
||||
tfm_infer.init_parameters_data()
|
||||
|
||||
parameter_dict = load_checkpoint(config.existed_ckpt)
|
||||
load_param_into_net(tfm_infer, parameter_dict)
|
||||
|
||||
model = Model(tfm_infer)
|
||||
|
||||
log_probs = []
|
||||
lengths = []
|
||||
source_sentences = []
|
||||
target_sentences = []
|
||||
for batch in dataset.create_dict_iterator():
|
||||
source_sentences.append(batch["source_eos_ids"])
|
||||
target_sentences.append(batch["target_eos_ids"])
|
||||
|
||||
source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
|
||||
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
|
||||
target_ids = Tensor(batch["target_sos_ids"], mstype.int32)
|
||||
target_mask = Tensor(batch["target_sos_mask"], mstype.int32)
|
||||
label_ids = Tensor(batch["target_eos_ids"], mstype.int32)
|
||||
label_mask = Tensor(batch["target_eos_mask"], mstype.int32)
|
||||
|
||||
start_time = time.time()
|
||||
log_prob, length = model.predict(source_ids, source_mask, target_ids, target_mask, label_ids, label_mask)
|
||||
print(f" | Batch size: {config.batch_size}, "
|
||||
f"Time cost: {time.time() - start_time}.")
|
||||
|
||||
log_probs.append(log_prob.asnumpy())
|
||||
lengths.append(length.asnumpy())
|
||||
|
||||
output = []
|
||||
for inputs, ref, log_prob, length in zip(source_sentences,
|
||||
target_sentences,
|
||||
log_probs,
|
||||
lengths):
|
||||
for i in range(config.batch_size):
|
||||
example = {
|
||||
"source": inputs[i].tolist(),
|
||||
"target": ref[i].tolist(),
|
||||
"log_prob": log_prob[i].tolist(),
|
||||
"length": length[i]
|
||||
}
|
||||
output.append(example)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def infer_ppl(config):
|
||||
"""
|
||||
Transformer infer PPL api.
|
||||
|
||||
Args:
|
||||
config (TransformerConfig): Config.
|
||||
|
||||
Returns:
|
||||
list, result with
|
||||
"""
|
||||
eval_dataset = load_dataset(data_files=config.test_dataset,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=1,
|
||||
sink_mode=config.dataset_sink_mode,
|
||||
shuffle=False) if config.test_dataset else None
|
||||
prediction = transformer_infer_ppl(config, eval_dataset)
|
||||
return prediction
|
||||
|
|
|
@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack
|
|||
from .byte_pair_encoding import bpe_encode
|
||||
from .initializer import zero_weight, one_weight, normal_weight, weight_variable
|
||||
from .rouge_score import rouge
|
||||
from .eval_score import get_score
|
||||
|
||||
__all__ = [
|
||||
"Dictionary",
|
||||
|
@ -31,5 +32,6 @@ __all__ = [
|
|||
"one_weight",
|
||||
"zero_weight",
|
||||
"normal_weight",
|
||||
"weight_variable"
|
||||
"weight_variable",
|
||||
"get_score"
|
||||
]
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Get score by given metric."""
|
||||
from .ppl_score import ngram_ppl
|
||||
from .rouge_score import rouge
|
||||
|
||||
|
||||
def get_ppl_score(result):
|
||||
"""
|
||||
Calculate Perplexity(PPL) score.
|
||||
|
||||
Args:
|
||||
List[Dict], prediction, each example has 4 keys, "source",
|
||||
"target", "log_prob" and "length".
|
||||
|
||||
Returns:
|
||||
Float, ppl score.
|
||||
"""
|
||||
log_probs = []
|
||||
total_length = 0
|
||||
|
||||
for sample in result:
|
||||
log_prob = sample['log_prob']
|
||||
length = sample['length']
|
||||
log_probs.extend(log_prob)
|
||||
total_length += length
|
||||
|
||||
print(f" | log_prob:{log_prob}")
|
||||
print(f" | length:{length}")
|
||||
|
||||
ppl = ngram_ppl(log_probs, total_length, log_softmax=True)
|
||||
print(f" | final PPL={ppl}.")
|
||||
return ppl
|
||||
|
||||
|
||||
def get_rouge_score(result, vocab):
|
||||
"""
|
||||
Calculate ROUGE score.
|
||||
|
||||
Args:
|
||||
List[Dict], prediction, each example has 4 keys, "source",
|
||||
"target", "prediction" and "prediction_prob".
|
||||
Dictionary, dict instance.
|
||||
|
||||
retur:
|
||||
Str, rouge score.
|
||||
"""
|
||||
|
||||
predictions = []
|
||||
targets = []
|
||||
for sample in result:
|
||||
predictions.append(' '.join([vocab[t] for t in sample['prediction']]))
|
||||
targets.append(' '.join([vocab[t] for t in sample['target']]))
|
||||
print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
|
||||
print(f" | target: {targets[-1]}")
|
||||
|
||||
return rouge(predictions, targets)
|
||||
|
||||
|
||||
def get_score(result, vocab=None, metric='rouge'):
|
||||
"""
|
||||
Get eval score.
|
||||
|
||||
Args:
|
||||
List[Dict], prediction.
|
||||
Dictionary, dict instance.
|
||||
Str, metric function, default is rouge.
|
||||
|
||||
Return:
|
||||
Str, Score.
|
||||
"""
|
||||
score = None
|
||||
if metric == 'rouge':
|
||||
score = get_rouge_score(result, vocab)
|
||||
elif metric == 'ppl':
|
||||
score = get_ppl_score(result)
|
||||
else:
|
||||
print(f" |metric not in (rouge, ppl)")
|
||||
|
||||
return score
|
|
@ -17,10 +17,7 @@ from typing import Union
|
|||
|
||||
import numpy as np
|
||||
|
||||
NINF = -1.0 * 1e9
|
||||
|
||||
|
||||
def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e):
|
||||
def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e):
|
||||
"""
|
||||
Calculate Perplexity(PPL) score under N-gram language model.
|
||||
|
||||
|
@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
|
|||
Returns:
|
||||
float, ppl score.
|
||||
"""
|
||||
eps = 1e-8
|
||||
if not length:
|
||||
return np.inf
|
||||
if not isinstance(prob, (np.ndarray, list)):
|
||||
raise TypeError("`prob` must be type of list or np.ndarray.")
|
||||
if not isinstance(prob, np.ndarray):
|
||||
|
@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
|
|||
if prob.shape[0] == 0:
|
||||
raise ValueError("`prob` length must greater than 0.")
|
||||
|
||||
p = 1.0
|
||||
sen_len = 0
|
||||
for t in range(prob.shape[0]):
|
||||
s = prob[t]
|
||||
if s <= NINF:
|
||||
break
|
||||
if log_softmax:
|
||||
s = np.power(index, s)
|
||||
p *= (1 / (s + eps))
|
||||
sen_len += 1
|
||||
print(f'length:{length}, log_prob:{prob}')
|
||||
|
||||
if sen_len == 0:
|
||||
return np.inf
|
||||
if log_softmax:
|
||||
prob = np.sum(prob) / length
|
||||
ppl = 1. / np.power(index, prob)
|
||||
print(f'avg log prob:{prob}')
|
||||
else:
|
||||
p = 1.
|
||||
for i in range(prob.shape[0]):
|
||||
p *= (1. / prob[i])
|
||||
ppl = pow(p, 1 / length)
|
||||
|
||||
return pow(p, 1 / sen_len)
|
||||
print(f'ppl val:{ppl}')
|
||||
return ppl
|
||||
|
|
Loading…
Reference in New Issue