diff --git a/mindspore/lite/examples/quick_start_flclient/build.sh b/mindspore/lite/examples/quick_start_flclient/build.sh new file mode 100644 index 00000000000..ebf36e41929 --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/build.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +display_usage() +{ + echo -e "\nUsage: build.sh [-r release.tar.gz]\n" +} + +checkopts() +{ + TARBALL="" + while getopts 'r:' opt + do + case "${opt}" in + r) + TARBALL=$OPTARG + ;; + *) + echo "Unknown option ${opt}!" + display_usage + exit 1 + esac + done +} + +checkopts "$@" + +BASEPATH=$(cd "$(dirname $0)" || exit; pwd) +get_version() { + VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} +} +get_version + +MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64" +MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz" +MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}" + +rm -rf build/${MINDSPORE_FILE_NAME} +rm -rf lib +mkdir -p build/${MINDSPORE_FILE_NAME} +mkdir -p lib + +if [ -n "$TARBALL" ]; then + cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE} +fi + +if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then + wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL} +fi + +tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1 + +cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/lib/* ${BASEPATH}/lib +cp ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/third_party/libjpeg-turbo/lib/*so* ${BASEPATH}/lib +cd ${BASEPATH}/ || exit + +mvn package diff --git a/mindspore/lite/examples/quick_start_flclient/pom.xml b/mindspore/lite/examples/quick_start_flclient/pom.xml new file mode 100644 index 00000000000..c45238cd233 --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/pom.xml @@ -0,0 +1,51 @@ + + + 4.0.0 + + com.mindspore.flclient.demo + quick_start_flclient + 1.0 + + + 8 + 8 + + + + + + com.mindspore.flclient + mindspore-lite-flclient + 1.0 + system + ${project.basedir}/lib/mindspore-lite-java-flclient.jar + + + + + ${project.name} + + + + org.apache.maven.plugins + maven-assembly-plugin + + + jar-with-dependencies + + + + + make-assemble + package + + single + + + + + + + diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/Feature.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/Feature.java new file mode 100644 index 00000000000..2076d1b2305 --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/Feature.java @@ -0,0 +1,47 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.demo.albert; + +/** + * feature class + * + * @since v1.0 + */ +public class Feature { + int[] inputIds; + int[] inputMasks; + int[] tokenIds; + int labelIds; + int seqLen; + + /** + * constructor + * + * @param inputIds input id + * @param inputMasks input masks + * @param tokenIds token ids + * @param labelIds label ids + * @param seqLen seq len + */ + public Feature(int[] inputIds, int[] inputMasks, int[] tokenIds, int labelIds, int seqLen) { + this.inputIds = inputIds; + this.inputMasks = inputMasks; + this.tokenIds = tokenIds; + this.labelIds = labelIds; + this.seqLen = seqLen; + } +} diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/ClassifierAccuracyCallback.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/ClassifierAccuracyCallback.java new file mode 100644 index 00000000000..9ccea12ffed --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/ClassifierAccuracyCallback.java @@ -0,0 +1,110 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.demo.common; + +import com.mindspore.flclient.model.Callback; +import com.mindspore.flclient.model.CommonUtils; +import com.mindspore.flclient.model.Status; +import com.mindspore.lite.LiteSession; +import com.mindspore.lite.MSTensor; + +import java.util.List; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * Defining the Callback calculate classifier model. + * + * @since v1.0 + */ +public class ClassifierAccuracyCallback extends Callback { + private static final Logger LOGGER = Logger.getLogger(ClassifierAccuracyCallback.class.toString()); + private final int numOfClass; + private final int batchSize; + private final List targetLabels; + private float accuracy; + + /** + * Defining a constructor of ClassifierAccuracyCallback. + */ + public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List targetLabels) { + super(session); + this.batchSize = batchSize; + this.numOfClass = numOfClass; + this.targetLabels = targetLabels; + } + + /** + * Get eval accuracy. + * + * @return accuracy. + */ + public float getAccuracy() { + return accuracy; + } + + @Override + public Status stepBegin() { + return Status.SUCCESS; + } + + @Override + public Status stepEnd() { + Status status = calAccuracy(); + if (status != Status.SUCCESS) { + return status; + } + steps++; + return Status.SUCCESS; + } + + @Override + public Status epochBegin() { + return Status.SUCCESS; + } + + @Override + public Status epochEnd() { + LOGGER.info("average accuracy:" + steps + ",acc is:" + accuracy / steps); + accuracy = accuracy / steps; + steps = 0; + return Status.SUCCESS; + } + + private Status calAccuracy() { + if (targetLabels == null || targetLabels.isEmpty()) { + LOGGER.severe("labels cannot be null"); + return Status.NULLPTR; + } + Optional outputTensor = searchOutputsForSize(batchSize * numOfClass); + if (!outputTensor.isPresent()) { + return Status.NULLPTR; + } + float[] scores = outputTensor.get().getFloatData(); + int hitCounts = 0; + for (int b = 0; b < batchSize; b++) { + int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass); + if (targetLabels.get(b + steps * batchSize) == predictIdx) { + hitCounts += 1; + } + } + accuracy += ((float) (hitCounts) / batchSize); + LOGGER.info("steps:" + steps + ",acc is:" + (float) (hitCounts) / batchSize); + return Status.SUCCESS; + } + +} diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/PredictCallback.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/PredictCallback.java new file mode 100644 index 00000000000..2def8822968 --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/PredictCallback.java @@ -0,0 +1,110 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.demo.common; + +import com.mindspore.flclient.model.Callback; +import com.mindspore.flclient.model.Status; +import com.mindspore.lite.LiteSession; +import com.mindspore.lite.MSTensor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * Defining the Callback get model predict result. + * + * @since v1.0 + */ +public class PredictCallback extends Callback { + private static final Logger LOGGER = Logger.getLogger(PredictCallback.class.toString()); + private final List predictResults = new ArrayList<>(); + private final int numOfClass; + private final int batchSize; + + /** + * Defining a constructor of predict callback. + */ + public PredictCallback(LiteSession session, int batchSize, int numOfClass) { + super(session); + this.batchSize = batchSize; + this.numOfClass = numOfClass; + } + + public static int getMaxScoreIndex(float[] scores, int start, int end) { + if (scores != null && scores.length != 0) { + if (start < scores.length && start >= 0 && end <= scores.length && end >= 0) { + float maxScore = scores[start]; + int maxIdx = start; + + for (int i = start; i < end; ++i) { + if (scores[i] > maxScore) { + maxIdx = i; + maxScore = scores[i]; + } + } + + return maxIdx - start; + } else { + LOGGER.severe("start,end cannot out of scores length"); + return -1; + } + } else { + LOGGER.severe("scores cannot be empty"); + return -1; + } + } + + /** + * Get predict results. + * + * @return predict result. + */ + public List getPredictResults() { + return predictResults; + } + + @Override + public Status stepBegin() { + return Status.SUCCESS; + } + + @Override + public Status stepEnd() { + Optional outputTensor = searchOutputsForSize(batchSize * numOfClass); + if (!outputTensor.isPresent()) { + return Status.FAILED; + } + float[] scores = outputTensor.get().getFloatData(); + for (int b = 0; b < batchSize; b++) { + int predictIdx = getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass); + predictResults.add(predictIdx); + } + return Status.SUCCESS; + } + + @Override + public Status epochBegin() { + return Status.SUCCESS; + } + + @Override + public Status epochEnd() { + return Status.SUCCESS; + } +} diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java new file mode 100644 index 00000000000..d42a549fc5d --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.demo.lenet; + +import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback; +import com.mindspore.flclient.demo.common.PredictCallback; +import com.mindspore.flclient.model.Callback; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.DataSet; +import com.mindspore.flclient.model.LossCallback; +import com.mindspore.flclient.model.RunType; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; + +/** + * Defining the lenet client base class. + * + * @since v1.0 + */ +public class LenetClient extends Client { + private static final Logger LOGGER = Logger.getLogger(LenetClient.class.toString()); + private static final int NUM_OF_CLASS = 62; + + static { + ClientManager.registerClient(new LenetClient()); + } + + @Override + public List initCallbacks(RunType runType, DataSet dataSet) { + List callbacks = new ArrayList<>(); + if (runType == RunType.TRAINMODE) { + Callback lossCallback = new LossCallback(trainSession); + callbacks.add(lossCallback); + } else if (runType == RunType.EVALMODE) { + if (dataSet instanceof LenetDataSet) { + Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS, + ((LenetDataSet) dataSet).getTargetLabels()); + callbacks.add(evalCallback); + } + } else { + Callback inferCallback = new PredictCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS); + callbacks.add(inferCallback); + } + return callbacks; + } + + @Override + public Map initDataSets(Map> files) { + Map sampleCounts = new HashMap<>(); + List trainFiles = files.getOrDefault(RunType.TRAINMODE, null); + if (trainFiles != null) { + DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS); + trainDataSet.init(trainFiles); + dataSets.put(RunType.TRAINMODE, trainDataSet); + sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize); + } + List evalFiles = files.getOrDefault(RunType.EVALMODE, null); + if (evalFiles != null) { + LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS); + evalDataSet.init(evalFiles); + dataSets.put(RunType.EVALMODE, evalDataSet); + sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize); + } + List inferFiles = files.getOrDefault(RunType.INFERMODE, null); + if (inferFiles != null) { + DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS); + inferDataSet.init(inferFiles); + dataSets.put(RunType.INFERMODE, inferDataSet); + sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize); + } + return sampleCounts; + } + + @Override + public float getEvalAccuracy(List evalCallbacks) { + for (Callback callBack : evalCallbacks) { + if (callBack instanceof ClassifierAccuracyCallback) { + return ((ClassifierAccuracyCallback) callBack).getAccuracy(); + } + } + LOGGER.severe("don not find accuracy related callback"); + return Float.NaN; + } + + @Override + public List getInferResult(List inferCallbacks) { + for (Callback callBack : inferCallbacks) { + if (callBack instanceof PredictCallback) { + return ((PredictCallback) callBack).getPredictResults(); + } + } + LOGGER.severe("don not find accuracy related callback"); + return new ArrayList<>(); + } +} diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetDataSet.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetDataSet.java new file mode 100644 index 00000000000..e286febee67 --- /dev/null +++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetDataSet.java @@ -0,0 +1,201 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.demo.lenet; + +import com.mindspore.flclient.model.DataSet; +import com.mindspore.flclient.model.Status; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.logging.Logger; + +/** + * Defining the minist dataset for lenet. + * + * @since v1.0 + */ +public class LenetDataSet extends DataSet { + private static final Logger LOGGER = Logger.getLogger(LenetDataSet.class.toString()); + private static final int IMAGE_SIZE = 32 * 32 * 3; + private static final int FLOAT_BYTE_SIZE = 4; + private byte[] imageArray; + private int[] labelArray; + private final int numOfClass; + private List targetLabels; + + /** + * Defining a constructor of lenet dataset. + */ + public LenetDataSet(int numOfClass) { + this.numOfClass = numOfClass; + } + + /** + * Get dataset labels. + * + * @return dataset target labels. + */ + public List getTargetLabels() { + return targetLabels; + } + + @Override + public void fillInputBuffer(List inputsBuffer, int batchIdx) { + // infer,train,eval model is same one + if (inputsBuffer.size() != 2) { + LOGGER.severe("input size error"); + return; + } + if (batchIdx > batchNum) { + LOGGER.severe("fill model image input failed"); + return; + } + for (ByteBuffer inputBuffer : inputsBuffer) { + inputBuffer.clear(); + } + ByteBuffer imageBuffer = inputsBuffer.get(0); + ByteBuffer labelIdBuffer = inputsBuffer.get(1); + int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES; + for (int i = 0; i < imageInputBytes; i++) { + imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]); + } + if (labelArray == null) { + return; + } + int labelSize = batchSize * numOfClass; + if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) { + LOGGER.severe("fill model label input failed"); + return; + } + labelIdBuffer.clear(); + for (int i = 0; i < labelSize; i++) { + labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]); + } + } + + @Override + public void shuffle() { + + } + + @Override + public void padding() { + if (labelArray == null) // infer model + { + labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)]; + Arrays.fill(labelArray, 0); + } + int curSize = labelArray.length / numOfClass; + int modSize = curSize - curSize / batchSize * batchSize; + int padSize = modSize != 0 ? batchSize * numOfClass - modSize : 0; + if (padSize != 0) { + int[] padLabelArray = new int[labelArray.length + padSize * numOfClass]; + byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES]; + System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length); + System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length); + for (int i = 0; i < padSize; i++) { + int idx = (int) (Math.random() * curSize); + System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass, + numOfClass); + System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray, + padImageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES); + } + labelArray = padLabelArray; + imageArray = padImageArray; + } + sampleSize = curSize + padSize; + batchNum = sampleSize / batchSize; + setPredictLabels(labelArray); + LOGGER.info("total samples:" + sampleSize); + LOGGER.info("total batchNum:" + batchNum); + } + + private void setPredictLabels(int[] labelArray) { + int labels_num = labelArray.length / numOfClass; + targetLabels = new ArrayList<>(labels_num); + for (int i = 0; i < labels_num; i++) { + int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1)); + if (label == -1) { + LOGGER.severe("get max index failed"); + } + targetLabels.add(label); + } + } + + private int getMaxIndex(int[] nums, int begin, int end) { + for (int i = begin; i < end; i++) { + if (nums[i] == 1) { + return i - begin; + } + } + return -1; + } + + public static byte[] readBinFile(String dataFile) { + if (dataFile == null || dataFile.isEmpty()) { + LOGGER.severe("file cannot be empty"); + return new byte[0]; + } + // read train file + Path path = Paths.get(dataFile); + byte[] data = new byte[0]; + try { + data = Files.readAllBytes(path); + } catch (IOException e) { + LOGGER.severe("read data file failed,please check data file path"); + } + return data; + } + + @Override + public Status dataPreprocess(List files) { + String labelFile = ""; + String imageFile; + if (files.size() == 2) { + imageFile = files.get(0); + labelFile = files.get(1); + } else if (files.size() == 1) { + imageFile = files.get(0); + } else { + LOGGER.severe("files size error"); + return Status.FAILED; + } + imageArray = readBinFile(imageFile); + if (labelFile != null && !labelFile.isEmpty()) { + byte[] labelByteArray = readBinFile(labelFile); + targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE); + // model labels use one hot + labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass]; + Arrays.fill(labelArray, 0); + int offset = 0; + for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) { + labelArray[offset * numOfClass + labelByteArray[i]] = 1; + offset++; + } + } else { + labelArray = null; // labelArray may be initialized from train + } + sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE; + return Status.SUCCESS; + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java index 7f20adad835..29fb9d15a96 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java @@ -152,6 +152,10 @@ public class SessionUtil { logger.severe(Common.addTag("trainSession cannot be null")); return; } + if(trainSession.getSessionPtr() == 0) { + logger.warning(Common.addTag("trainSession pointer has already free")); + return; + } trainSession.free(); } }