forked from mindspore-Ecosystem/mindspore
Fix full batch error
This commit is contained in:
parent
84957cc4a7
commit
640f7194b9
|
@ -18,9 +18,7 @@ Area under cure metric
|
|||
"""
|
||||
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from mindspore import context
|
||||
from mindspore.nn.metrics import Metric
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
|
||||
class AUCMetric(Metric):
|
||||
"""
|
||||
|
@ -30,7 +28,6 @@ class AUCMetric(Metric):
|
|||
def __init__(self):
|
||||
super(AUCMetric, self).__init__()
|
||||
self.clear()
|
||||
self.full_batch = context.get_auto_parallel_context("full_batch")
|
||||
|
||||
def clear(self):
|
||||
"""Clear the internal evaluation result."""
|
||||
|
@ -42,13 +39,7 @@ class AUCMetric(Metric):
|
|||
all_predict = inputs[1].asnumpy().flatten().tolist() # predict
|
||||
all_label = inputs[2].asnumpy().flatten().tolist() # label
|
||||
self.pred_probs.extend(all_predict)
|
||||
if self.full_batch:
|
||||
rank_id = get_rank()
|
||||
group_size = get_group_size()
|
||||
gap = len(all_label) // group_size
|
||||
self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap])
|
||||
else:
|
||||
self.true_labels.extend(all_label)
|
||||
self.true_labels.extend(all_label)
|
||||
|
||||
def eval(self):
|
||||
if len(self.true_labels) != len(self.pred_probs):
|
||||
|
|
Loading…
Reference in New Issue