fix acc bug of albert for fl_client in r1.3

This commit is contained in:
zhoushan 2021-07-08 15:37:27 +08:00
parent 48f0daa9a9
commit 2488333828
1 changed files with 24 additions and 9 deletions

View File

@ -430,17 +430,32 @@ public class FLLiteClient {
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server==================================="));
if (flParameter.getFlName().equals(ALBERT)) {
AlInferBert alInferBert = AlInferBert.getInstance();
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlInferBert alInferBert = AlInferBert.getInstance();
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
acc = alInferBert.evalModel();
} else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
acc = alTrainBert.evalModel();
}
float acc = alInferBert.evalModel();
if (acc == Float.NaN) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.evalModel>: the return acc is NAN"));
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;