diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index edcc2391a76..edd239158a7 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -430,17 +430,32 @@ public class FLLiteClient { retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("===================================evaluate model after getting model from server===================================")); if (flParameter.getFlName().equals(ALBERT)) { - AlInferBert alInferBert = AlInferBert.getInstance(); - int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true); - if (dataSize <= 0) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; + float acc = 0; + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); + AlInferBert alInferBert = AlInferBert.getInstance(); + int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true); + if (dataSize <= 0) { + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + return status; + } + acc = alInferBert.evalModel(); + } else { + LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile()); + if (dataSize <= 0) { + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + return status; + } + acc = alTrainBert.evalModel(); } - float acc = alInferBert.evalModel(); if (acc == Float.NaN) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); status = FLClientStatus.FAILED; retCode = ResponseCode.RequestError; return status;