add flclient lenet example

This commit is contained in:
zhengjun10 2021-12-16 15:49:49 +08:00
parent 901d81d16d
commit b9e716ab7a
8 changed files with 710 additions and 0 deletions

View File

@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
display_usage()
{
echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
}
checkopts()
{
TARBALL=""
while getopts 'r:' opt
do
case "${opt}" in
r)
TARBALL=$OPTARG
;;
*)
echo "Unknown option ${opt}!"
display_usage
exit 1
esac
done
}
checkopts "$@"
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
get_version() {
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
}
get_version
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
rm -rf build/${MINDSPORE_FILE_NAME}
rm -rf lib
mkdir -p build/${MINDSPORE_FILE_NAME}
mkdir -p lib
if [ -n "$TARBALL" ]; then
cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
fi
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
fi
tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/lib/* ${BASEPATH}/lib
cp ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/third_party/libjpeg-turbo/lib/*so* ${BASEPATH}/lib
cd ${BASEPATH}/ || exit
mvn package

View File

@ -0,0 +1,51 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.mindspore.flclient.demo</groupId>
<artifactId>quick_start_flclient</artifactId>
<version>1.0</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>com.mindspore.flclient</groupId>
<artifactId>mindspore-lite-flclient</artifactId>
<version>1.0</version>
<scope>system</scope>
<systemPath>${project.basedir}/lib/mindspore-lite-java-flclient.jar</systemPath>
</dependency>
</dependencies>
<build>
<finalName>${project.name}</finalName>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assemble</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,47 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.albert;
/**
* feature class
*
* @since v1.0
*/
public class Feature {
int[] inputIds;
int[] inputMasks;
int[] tokenIds;
int labelIds;
int seqLen;
/**
* constructor
*
* @param inputIds input id
* @param inputMasks input masks
* @param tokenIds token ids
* @param labelIds label ids
* @param seqLen seq len
*/
public Feature(int[] inputIds, int[] inputMasks, int[] tokenIds, int labelIds, int seqLen) {
this.inputIds = inputIds;
this.inputMasks = inputMasks;
this.tokenIds = tokenIds;
this.labelIds = labelIds;
this.seqLen = seqLen;
}
}

View File

@ -0,0 +1,110 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.common;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.Status;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
/**
* Defining the Callback calculate classifier model.
*
* @since v1.0
*/
public class ClassifierAccuracyCallback extends Callback {
private static final Logger LOGGER = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
private final int numOfClass;
private final int batchSize;
private final List<Integer> targetLabels;
private float accuracy;
/**
* Defining a constructor of ClassifierAccuracyCallback.
*/
public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List<Integer> targetLabels) {
super(session);
this.batchSize = batchSize;
this.numOfClass = numOfClass;
this.targetLabels = targetLabels;
}
/**
* Get eval accuracy.
*
* @return accuracy.
*/
public float getAccuracy() {
return accuracy;
}
@Override
public Status stepBegin() {
return Status.SUCCESS;
}
@Override
public Status stepEnd() {
Status status = calAccuracy();
if (status != Status.SUCCESS) {
return status;
}
steps++;
return Status.SUCCESS;
}
@Override
public Status epochBegin() {
return Status.SUCCESS;
}
@Override
public Status epochEnd() {
LOGGER.info("average accuracy:" + steps + ",acc is:" + accuracy / steps);
accuracy = accuracy / steps;
steps = 0;
return Status.SUCCESS;
}
private Status calAccuracy() {
if (targetLabels == null || targetLabels.isEmpty()) {
LOGGER.severe("labels cannot be null");
return Status.NULLPTR;
}
Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
if (!outputTensor.isPresent()) {
return Status.NULLPTR;
}
float[] scores = outputTensor.get().getFloatData();
int hitCounts = 0;
for (int b = 0; b < batchSize; b++) {
int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
if (targetLabels.get(b + steps * batchSize) == predictIdx) {
hitCounts += 1;
}
}
accuracy += ((float) (hitCounts) / batchSize);
LOGGER.info("steps:" + steps + ",acc is:" + (float) (hitCounts) / batchSize);
return Status.SUCCESS;
}
}

View File

@ -0,0 +1,110 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.common;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.Status;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
/**
* Defining the Callback get model predict result.
*
* @since v1.0
*/
public class PredictCallback extends Callback {
private static final Logger LOGGER = Logger.getLogger(PredictCallback.class.toString());
private final List<Integer> predictResults = new ArrayList<>();
private final int numOfClass;
private final int batchSize;
/**
* Defining a constructor of predict callback.
*/
public PredictCallback(LiteSession session, int batchSize, int numOfClass) {
super(session);
this.batchSize = batchSize;
this.numOfClass = numOfClass;
}
public static int getMaxScoreIndex(float[] scores, int start, int end) {
if (scores != null && scores.length != 0) {
if (start < scores.length && start >= 0 && end <= scores.length && end >= 0) {
float maxScore = scores[start];
int maxIdx = start;
for (int i = start; i < end; ++i) {
if (scores[i] > maxScore) {
maxIdx = i;
maxScore = scores[i];
}
}
return maxIdx - start;
} else {
LOGGER.severe("start,end cannot out of scores length");
return -1;
}
} else {
LOGGER.severe("scores cannot be empty");
return -1;
}
}
/**
* Get predict results.
*
* @return predict result.
*/
public List<Integer> getPredictResults() {
return predictResults;
}
@Override
public Status stepBegin() {
return Status.SUCCESS;
}
@Override
public Status stepEnd() {
Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
if (!outputTensor.isPresent()) {
return Status.FAILED;
}
float[] scores = outputTensor.get().getFloatData();
for (int b = 0; b < batchSize; b++) {
int predictIdx = getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
predictResults.add(predictIdx);
}
return Status.SUCCESS;
}
@Override
public Status epochBegin() {
return Status.SUCCESS;
}
@Override
public Status epochEnd() {
return Status.SUCCESS;
}
}

View File

@ -0,0 +1,114 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.lenet;
import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback;
import com.mindspore.flclient.demo.common.PredictCallback;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.LossCallback;
import com.mindspore.flclient.model.RunType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
/**
* Defining the lenet client base class.
*
* @since v1.0
*/
public class LenetClient extends Client {
private static final Logger LOGGER = Logger.getLogger(LenetClient.class.toString());
private static final int NUM_OF_CLASS = 62;
static {
ClientManager.registerClient(new LenetClient());
}
@Override
public List<Callback> initCallbacks(RunType runType, DataSet dataSet) {
List<Callback> callbacks = new ArrayList<>();
if (runType == RunType.TRAINMODE) {
Callback lossCallback = new LossCallback(trainSession);
callbacks.add(lossCallback);
} else if (runType == RunType.EVALMODE) {
if (dataSet instanceof LenetDataSet) {
Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS,
((LenetDataSet) dataSet).getTargetLabels());
callbacks.add(evalCallback);
}
} else {
Callback inferCallback = new PredictCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS);
callbacks.add(inferCallback);
}
return callbacks;
}
@Override
public Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files) {
Map<RunType, Integer> sampleCounts = new HashMap<>();
List<String> trainFiles = files.getOrDefault(RunType.TRAINMODE, null);
if (trainFiles != null) {
DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS);
trainDataSet.init(trainFiles);
dataSets.put(RunType.TRAINMODE, trainDataSet);
sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize);
}
List<String> evalFiles = files.getOrDefault(RunType.EVALMODE, null);
if (evalFiles != null) {
LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS);
evalDataSet.init(evalFiles);
dataSets.put(RunType.EVALMODE, evalDataSet);
sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize);
}
List<String> inferFiles = files.getOrDefault(RunType.INFERMODE, null);
if (inferFiles != null) {
DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS);
inferDataSet.init(inferFiles);
dataSets.put(RunType.INFERMODE, inferDataSet);
sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize);
}
return sampleCounts;
}
@Override
public float getEvalAccuracy(List<Callback> evalCallbacks) {
for (Callback callBack : evalCallbacks) {
if (callBack instanceof ClassifierAccuracyCallback) {
return ((ClassifierAccuracyCallback) callBack).getAccuracy();
}
}
LOGGER.severe("don not find accuracy related callback");
return Float.NaN;
}
@Override
public List<Integer> getInferResult(List<Callback> inferCallbacks) {
for (Callback callBack : inferCallbacks) {
if (callBack instanceof PredictCallback) {
return ((PredictCallback) callBack).getPredictResults();
}
}
LOGGER.severe("don not find accuracy related callback");
return new ArrayList<>();
}
}

View File

@ -0,0 +1,201 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.lenet;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.Status;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
/**
* Defining the minist dataset for lenet.
*
* @since v1.0
*/
public class LenetDataSet extends DataSet {
private static final Logger LOGGER = Logger.getLogger(LenetDataSet.class.toString());
private static final int IMAGE_SIZE = 32 * 32 * 3;
private static final int FLOAT_BYTE_SIZE = 4;
private byte[] imageArray;
private int[] labelArray;
private final int numOfClass;
private List<Integer> targetLabels;
/**
* Defining a constructor of lenet dataset.
*/
public LenetDataSet(int numOfClass) {
this.numOfClass = numOfClass;
}
/**
* Get dataset labels.
*
* @return dataset target labels.
*/
public List<Integer> getTargetLabels() {
return targetLabels;
}
@Override
public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
// infer,train,eval model is same one
if (inputsBuffer.size() != 2) {
LOGGER.severe("input size error");
return;
}
if (batchIdx > batchNum) {
LOGGER.severe("fill model image input failed");
return;
}
for (ByteBuffer inputBuffer : inputsBuffer) {
inputBuffer.clear();
}
ByteBuffer imageBuffer = inputsBuffer.get(0);
ByteBuffer labelIdBuffer = inputsBuffer.get(1);
int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES;
for (int i = 0; i < imageInputBytes; i++) {
imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]);
}
if (labelArray == null) {
return;
}
int labelSize = batchSize * numOfClass;
if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) {
LOGGER.severe("fill model label input failed");
return;
}
labelIdBuffer.clear();
for (int i = 0; i < labelSize; i++) {
labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]);
}
}
@Override
public void shuffle() {
}
@Override
public void padding() {
if (labelArray == null) // infer model
{
labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)];
Arrays.fill(labelArray, 0);
}
int curSize = labelArray.length / numOfClass;
int modSize = curSize - curSize / batchSize * batchSize;
int padSize = modSize != 0 ? batchSize * numOfClass - modSize : 0;
if (padSize != 0) {
int[] padLabelArray = new int[labelArray.length + padSize * numOfClass];
byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES];
System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length);
System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length);
for (int i = 0; i < padSize; i++) {
int idx = (int) (Math.random() * curSize);
System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass,
numOfClass);
System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray,
padImageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES);
}
labelArray = padLabelArray;
imageArray = padImageArray;
}
sampleSize = curSize + padSize;
batchNum = sampleSize / batchSize;
setPredictLabels(labelArray);
LOGGER.info("total samples:" + sampleSize);
LOGGER.info("total batchNum:" + batchNum);
}
private void setPredictLabels(int[] labelArray) {
int labels_num = labelArray.length / numOfClass;
targetLabels = new ArrayList<>(labels_num);
for (int i = 0; i < labels_num; i++) {
int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1));
if (label == -1) {
LOGGER.severe("get max index failed");
}
targetLabels.add(label);
}
}
private int getMaxIndex(int[] nums, int begin, int end) {
for (int i = begin; i < end; i++) {
if (nums[i] == 1) {
return i - begin;
}
}
return -1;
}
public static byte[] readBinFile(String dataFile) {
if (dataFile == null || dataFile.isEmpty()) {
LOGGER.severe("file cannot be empty");
return new byte[0];
}
// read train file
Path path = Paths.get(dataFile);
byte[] data = new byte[0];
try {
data = Files.readAllBytes(path);
} catch (IOException e) {
LOGGER.severe("read data file failed,please check data file path");
}
return data;
}
@Override
public Status dataPreprocess(List<String> files) {
String labelFile = "";
String imageFile;
if (files.size() == 2) {
imageFile = files.get(0);
labelFile = files.get(1);
} else if (files.size() == 1) {
imageFile = files.get(0);
} else {
LOGGER.severe("files size error");
return Status.FAILED;
}
imageArray = readBinFile(imageFile);
if (labelFile != null && !labelFile.isEmpty()) {
byte[] labelByteArray = readBinFile(labelFile);
targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE);
// model labels use one hot
labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass];
Arrays.fill(labelArray, 0);
int offset = 0;
for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) {
labelArray[offset * numOfClass + labelByteArray[i]] = 1;
offset++;
}
} else {
labelArray = null; // labelArray may be initialized from train
}
sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE;
return Status.SUCCESS;
}
}

View File

@ -152,6 +152,10 @@ public class SessionUtil {
logger.severe(Common.addTag("trainSession cannot be null"));
return;
}
if(trainSession.getSessionPtr() == 0) {
logger.warning(Common.addTag("trainSession pointer has already free"));
return;
}
trainSession.free();
}
}