diff --git a/mindspore/lite/examples/quick_start_flclient/build.sh b/mindspore/lite/examples/quick_start_flclient/build.sh
new file mode 100644
index 00000000000..ebf36e41929
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/build.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+display_usage()
+{
+ echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
+}
+
+checkopts()
+{
+ TARBALL=""
+ while getopts 'r:' opt
+ do
+ case "${opt}" in
+ r)
+ TARBALL=$OPTARG
+ ;;
+ *)
+ echo "Unknown option ${opt}!"
+ display_usage
+ exit 1
+ esac
+ done
+}
+
+checkopts "$@"
+
+BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
+get_version() {
+ VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
+ VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
+ VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
+ VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
+}
+get_version
+
+MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
+MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
+MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
+
+rm -rf build/${MINDSPORE_FILE_NAME}
+rm -rf lib
+mkdir -p build/${MINDSPORE_FILE_NAME}
+mkdir -p lib
+
+if [ -n "$TARBALL" ]; then
+ cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
+fi
+
+if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
+ wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
+fi
+
+tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
+
+cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/lib/* ${BASEPATH}/lib
+cp ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/third_party/libjpeg-turbo/lib/*so* ${BASEPATH}/lib
+cd ${BASEPATH}/ || exit
+
+mvn package
diff --git a/mindspore/lite/examples/quick_start_flclient/pom.xml b/mindspore/lite/examples/quick_start_flclient/pom.xml
new file mode 100644
index 00000000000..c45238cd233
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/pom.xml
@@ -0,0 +1,51 @@
+
+
+ 4.0.0
+
+ com.mindspore.flclient.demo
+ quick_start_flclient
+ 1.0
+
+
+ 8
+ 8
+
+
+
+
+
+ com.mindspore.flclient
+ mindspore-lite-flclient
+ 1.0
+ system
+ ${project.basedir}/lib/mindspore-lite-java-flclient.jar
+
+
+
+
+ ${project.name}
+
+
+
+ org.apache.maven.plugins
+ maven-assembly-plugin
+
+
+ jar-with-dependencies
+
+
+
+
+ make-assemble
+ package
+
+ single
+
+
+
+
+
+
+
diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/Feature.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/Feature.java
new file mode 100644
index 00000000000..2076d1b2305
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/Feature.java
@@ -0,0 +1,47 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.flclient.demo.albert;
+
+/**
+ * feature class
+ *
+ * @since v1.0
+ */
+public class Feature {
+ int[] inputIds;
+ int[] inputMasks;
+ int[] tokenIds;
+ int labelIds;
+ int seqLen;
+
+ /**
+ * constructor
+ *
+ * @param inputIds input id
+ * @param inputMasks input masks
+ * @param tokenIds token ids
+ * @param labelIds label ids
+ * @param seqLen seq len
+ */
+ public Feature(int[] inputIds, int[] inputMasks, int[] tokenIds, int labelIds, int seqLen) {
+ this.inputIds = inputIds;
+ this.inputMasks = inputMasks;
+ this.tokenIds = tokenIds;
+ this.labelIds = labelIds;
+ this.seqLen = seqLen;
+ }
+}
diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/ClassifierAccuracyCallback.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/ClassifierAccuracyCallback.java
new file mode 100644
index 00000000000..9ccea12ffed
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/ClassifierAccuracyCallback.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.flclient.demo.common;
+
+import com.mindspore.flclient.model.Callback;
+import com.mindspore.flclient.model.CommonUtils;
+import com.mindspore.flclient.model.Status;
+import com.mindspore.lite.LiteSession;
+import com.mindspore.lite.MSTensor;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * Defining the Callback calculate classifier model.
+ *
+ * @since v1.0
+ */
+public class ClassifierAccuracyCallback extends Callback {
+ private static final Logger LOGGER = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
+ private final int numOfClass;
+ private final int batchSize;
+ private final List targetLabels;
+ private float accuracy;
+
+ /**
+ * Defining a constructor of ClassifierAccuracyCallback.
+ */
+ public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List targetLabels) {
+ super(session);
+ this.batchSize = batchSize;
+ this.numOfClass = numOfClass;
+ this.targetLabels = targetLabels;
+ }
+
+ /**
+ * Get eval accuracy.
+ *
+ * @return accuracy.
+ */
+ public float getAccuracy() {
+ return accuracy;
+ }
+
+ @Override
+ public Status stepBegin() {
+ return Status.SUCCESS;
+ }
+
+ @Override
+ public Status stepEnd() {
+ Status status = calAccuracy();
+ if (status != Status.SUCCESS) {
+ return status;
+ }
+ steps++;
+ return Status.SUCCESS;
+ }
+
+ @Override
+ public Status epochBegin() {
+ return Status.SUCCESS;
+ }
+
+ @Override
+ public Status epochEnd() {
+ LOGGER.info("average accuracy:" + steps + ",acc is:" + accuracy / steps);
+ accuracy = accuracy / steps;
+ steps = 0;
+ return Status.SUCCESS;
+ }
+
+ private Status calAccuracy() {
+ if (targetLabels == null || targetLabels.isEmpty()) {
+ LOGGER.severe("labels cannot be null");
+ return Status.NULLPTR;
+ }
+ Optional outputTensor = searchOutputsForSize(batchSize * numOfClass);
+ if (!outputTensor.isPresent()) {
+ return Status.NULLPTR;
+ }
+ float[] scores = outputTensor.get().getFloatData();
+ int hitCounts = 0;
+ for (int b = 0; b < batchSize; b++) {
+ int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
+ if (targetLabels.get(b + steps * batchSize) == predictIdx) {
+ hitCounts += 1;
+ }
+ }
+ accuracy += ((float) (hitCounts) / batchSize);
+ LOGGER.info("steps:" + steps + ",acc is:" + (float) (hitCounts) / batchSize);
+ return Status.SUCCESS;
+ }
+
+}
diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/PredictCallback.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/PredictCallback.java
new file mode 100644
index 00000000000..2def8822968
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/common/PredictCallback.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.flclient.demo.common;
+
+import com.mindspore.flclient.model.Callback;
+import com.mindspore.flclient.model.Status;
+import com.mindspore.lite.LiteSession;
+import com.mindspore.lite.MSTensor;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * Defining the Callback get model predict result.
+ *
+ * @since v1.0
+ */
+public class PredictCallback extends Callback {
+ private static final Logger LOGGER = Logger.getLogger(PredictCallback.class.toString());
+ private final List predictResults = new ArrayList<>();
+ private final int numOfClass;
+ private final int batchSize;
+
+ /**
+ * Defining a constructor of predict callback.
+ */
+ public PredictCallback(LiteSession session, int batchSize, int numOfClass) {
+ super(session);
+ this.batchSize = batchSize;
+ this.numOfClass = numOfClass;
+ }
+
+ public static int getMaxScoreIndex(float[] scores, int start, int end) {
+ if (scores != null && scores.length != 0) {
+ if (start < scores.length && start >= 0 && end <= scores.length && end >= 0) {
+ float maxScore = scores[start];
+ int maxIdx = start;
+
+ for (int i = start; i < end; ++i) {
+ if (scores[i] > maxScore) {
+ maxIdx = i;
+ maxScore = scores[i];
+ }
+ }
+
+ return maxIdx - start;
+ } else {
+ LOGGER.severe("start,end cannot out of scores length");
+ return -1;
+ }
+ } else {
+ LOGGER.severe("scores cannot be empty");
+ return -1;
+ }
+ }
+
+ /**
+ * Get predict results.
+ *
+ * @return predict result.
+ */
+ public List getPredictResults() {
+ return predictResults;
+ }
+
+ @Override
+ public Status stepBegin() {
+ return Status.SUCCESS;
+ }
+
+ @Override
+ public Status stepEnd() {
+ Optional outputTensor = searchOutputsForSize(batchSize * numOfClass);
+ if (!outputTensor.isPresent()) {
+ return Status.FAILED;
+ }
+ float[] scores = outputTensor.get().getFloatData();
+ for (int b = 0; b < batchSize; b++) {
+ int predictIdx = getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
+ predictResults.add(predictIdx);
+ }
+ return Status.SUCCESS;
+ }
+
+ @Override
+ public Status epochBegin() {
+ return Status.SUCCESS;
+ }
+
+ @Override
+ public Status epochEnd() {
+ return Status.SUCCESS;
+ }
+}
diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java
new file mode 100644
index 00000000000..d42a549fc5d
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.flclient.demo.lenet;
+
+import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback;
+import com.mindspore.flclient.demo.common.PredictCallback;
+import com.mindspore.flclient.model.Callback;
+import com.mindspore.flclient.model.Client;
+import com.mindspore.flclient.model.ClientManager;
+import com.mindspore.flclient.model.DataSet;
+import com.mindspore.flclient.model.LossCallback;
+import com.mindspore.flclient.model.RunType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.logging.Logger;
+
+/**
+ * Defining the lenet client base class.
+ *
+ * @since v1.0
+ */
+public class LenetClient extends Client {
+ private static final Logger LOGGER = Logger.getLogger(LenetClient.class.toString());
+ private static final int NUM_OF_CLASS = 62;
+
+ static {
+ ClientManager.registerClient(new LenetClient());
+ }
+
+ @Override
+ public List initCallbacks(RunType runType, DataSet dataSet) {
+ List callbacks = new ArrayList<>();
+ if (runType == RunType.TRAINMODE) {
+ Callback lossCallback = new LossCallback(trainSession);
+ callbacks.add(lossCallback);
+ } else if (runType == RunType.EVALMODE) {
+ if (dataSet instanceof LenetDataSet) {
+ Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS,
+ ((LenetDataSet) dataSet).getTargetLabels());
+ callbacks.add(evalCallback);
+ }
+ } else {
+ Callback inferCallback = new PredictCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS);
+ callbacks.add(inferCallback);
+ }
+ return callbacks;
+ }
+
+ @Override
+ public Map initDataSets(Map> files) {
+ Map sampleCounts = new HashMap<>();
+ List trainFiles = files.getOrDefault(RunType.TRAINMODE, null);
+ if (trainFiles != null) {
+ DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS);
+ trainDataSet.init(trainFiles);
+ dataSets.put(RunType.TRAINMODE, trainDataSet);
+ sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize);
+ }
+ List evalFiles = files.getOrDefault(RunType.EVALMODE, null);
+ if (evalFiles != null) {
+ LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS);
+ evalDataSet.init(evalFiles);
+ dataSets.put(RunType.EVALMODE, evalDataSet);
+ sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize);
+ }
+ List inferFiles = files.getOrDefault(RunType.INFERMODE, null);
+ if (inferFiles != null) {
+ DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS);
+ inferDataSet.init(inferFiles);
+ dataSets.put(RunType.INFERMODE, inferDataSet);
+ sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize);
+ }
+ return sampleCounts;
+ }
+
+ @Override
+ public float getEvalAccuracy(List evalCallbacks) {
+ for (Callback callBack : evalCallbacks) {
+ if (callBack instanceof ClassifierAccuracyCallback) {
+ return ((ClassifierAccuracyCallback) callBack).getAccuracy();
+ }
+ }
+ LOGGER.severe("don not find accuracy related callback");
+ return Float.NaN;
+ }
+
+ @Override
+ public List getInferResult(List inferCallbacks) {
+ for (Callback callBack : inferCallbacks) {
+ if (callBack instanceof PredictCallback) {
+ return ((PredictCallback) callBack).getPredictResults();
+ }
+ }
+ LOGGER.severe("don not find accuracy related callback");
+ return new ArrayList<>();
+ }
+}
diff --git a/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetDataSet.java b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetDataSet.java
new file mode 100644
index 00000000000..e286febee67
--- /dev/null
+++ b/mindspore/lite/examples/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetDataSet.java
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mindspore.flclient.demo.lenet;
+
+import com.mindspore.flclient.model.DataSet;
+import com.mindspore.flclient.model.Status;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.logging.Logger;
+
+/**
+ * Defining the minist dataset for lenet.
+ *
+ * @since v1.0
+ */
+public class LenetDataSet extends DataSet {
+ private static final Logger LOGGER = Logger.getLogger(LenetDataSet.class.toString());
+ private static final int IMAGE_SIZE = 32 * 32 * 3;
+ private static final int FLOAT_BYTE_SIZE = 4;
+ private byte[] imageArray;
+ private int[] labelArray;
+ private final int numOfClass;
+ private List targetLabels;
+
+ /**
+ * Defining a constructor of lenet dataset.
+ */
+ public LenetDataSet(int numOfClass) {
+ this.numOfClass = numOfClass;
+ }
+
+ /**
+ * Get dataset labels.
+ *
+ * @return dataset target labels.
+ */
+ public List getTargetLabels() {
+ return targetLabels;
+ }
+
+ @Override
+ public void fillInputBuffer(List inputsBuffer, int batchIdx) {
+ // infer,train,eval model is same one
+ if (inputsBuffer.size() != 2) {
+ LOGGER.severe("input size error");
+ return;
+ }
+ if (batchIdx > batchNum) {
+ LOGGER.severe("fill model image input failed");
+ return;
+ }
+ for (ByteBuffer inputBuffer : inputsBuffer) {
+ inputBuffer.clear();
+ }
+ ByteBuffer imageBuffer = inputsBuffer.get(0);
+ ByteBuffer labelIdBuffer = inputsBuffer.get(1);
+ int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES;
+ for (int i = 0; i < imageInputBytes; i++) {
+ imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]);
+ }
+ if (labelArray == null) {
+ return;
+ }
+ int labelSize = batchSize * numOfClass;
+ if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) {
+ LOGGER.severe("fill model label input failed");
+ return;
+ }
+ labelIdBuffer.clear();
+ for (int i = 0; i < labelSize; i++) {
+ labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]);
+ }
+ }
+
+ @Override
+ public void shuffle() {
+
+ }
+
+ @Override
+ public void padding() {
+ if (labelArray == null) // infer model
+ {
+ labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)];
+ Arrays.fill(labelArray, 0);
+ }
+ int curSize = labelArray.length / numOfClass;
+ int modSize = curSize - curSize / batchSize * batchSize;
+ int padSize = modSize != 0 ? batchSize * numOfClass - modSize : 0;
+ if (padSize != 0) {
+ int[] padLabelArray = new int[labelArray.length + padSize * numOfClass];
+ byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES];
+ System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length);
+ System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length);
+ for (int i = 0; i < padSize; i++) {
+ int idx = (int) (Math.random() * curSize);
+ System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass,
+ numOfClass);
+ System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray,
+ padImageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES);
+ }
+ labelArray = padLabelArray;
+ imageArray = padImageArray;
+ }
+ sampleSize = curSize + padSize;
+ batchNum = sampleSize / batchSize;
+ setPredictLabels(labelArray);
+ LOGGER.info("total samples:" + sampleSize);
+ LOGGER.info("total batchNum:" + batchNum);
+ }
+
+ private void setPredictLabels(int[] labelArray) {
+ int labels_num = labelArray.length / numOfClass;
+ targetLabels = new ArrayList<>(labels_num);
+ for (int i = 0; i < labels_num; i++) {
+ int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1));
+ if (label == -1) {
+ LOGGER.severe("get max index failed");
+ }
+ targetLabels.add(label);
+ }
+ }
+
+ private int getMaxIndex(int[] nums, int begin, int end) {
+ for (int i = begin; i < end; i++) {
+ if (nums[i] == 1) {
+ return i - begin;
+ }
+ }
+ return -1;
+ }
+
+ public static byte[] readBinFile(String dataFile) {
+ if (dataFile == null || dataFile.isEmpty()) {
+ LOGGER.severe("file cannot be empty");
+ return new byte[0];
+ }
+ // read train file
+ Path path = Paths.get(dataFile);
+ byte[] data = new byte[0];
+ try {
+ data = Files.readAllBytes(path);
+ } catch (IOException e) {
+ LOGGER.severe("read data file failed,please check data file path");
+ }
+ return data;
+ }
+
+ @Override
+ public Status dataPreprocess(List files) {
+ String labelFile = "";
+ String imageFile;
+ if (files.size() == 2) {
+ imageFile = files.get(0);
+ labelFile = files.get(1);
+ } else if (files.size() == 1) {
+ imageFile = files.get(0);
+ } else {
+ LOGGER.severe("files size error");
+ return Status.FAILED;
+ }
+ imageArray = readBinFile(imageFile);
+ if (labelFile != null && !labelFile.isEmpty()) {
+ byte[] labelByteArray = readBinFile(labelFile);
+ targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE);
+ // model labels use one hot
+ labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass];
+ Arrays.fill(labelArray, 0);
+ int offset = 0;
+ for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) {
+ labelArray[offset * numOfClass + labelByteArray[i]] = 1;
+ offset++;
+ }
+ } else {
+ labelArray = null; // labelArray may be initialized from train
+ }
+ sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE;
+ return Status.SUCCESS;
+ }
+}
diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java
index 7f20adad835..29fb9d15a96 100644
--- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java
+++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java
@@ -152,6 +152,10 @@ public class SessionUtil {
logger.severe(Common.addTag("trainSession cannot be null"));
return;
}
+ if(trainSession.getSessionPtr() == 0) {
+ logger.warning(Common.addTag("trainSession pointer has already free"));
+ return;
+ }
trainSession.free();
}
}