forked from mindspore-Ecosystem/mindspore
add flclient lenet example
This commit is contained in:
parent
901d81d16d
commit
b9e716ab7a
|
@ -0,0 +1,73 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
display_usage()
|
||||
{
|
||||
echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
|
||||
}
|
||||
|
||||
checkopts()
|
||||
{
|
||||
TARBALL=""
|
||||
while getopts 'r:' opt
|
||||
do
|
||||
case "${opt}" in
|
||||
r)
|
||||
TARBALL=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option ${opt}!"
|
||||
display_usage
|
||||
exit 1
|
||||
esac
|
||||
done
|
||||
}
|
||||
|
||||
checkopts "$@"
|
||||
|
||||
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
|
||||
get_version() {
|
||||
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
|
||||
}
|
||||
get_version
|
||||
|
||||
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
|
||||
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
|
||||
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
|
||||
|
||||
rm -rf build/${MINDSPORE_FILE_NAME}
|
||||
rm -rf lib
|
||||
mkdir -p build/${MINDSPORE_FILE_NAME}
|
||||
mkdir -p lib
|
||||
|
||||
if [ -n "$TARBALL" ]; then
|
||||
cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
|
||||
fi
|
||||
|
||||
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
|
||||
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
|
||||
fi
|
||||
|
||||
tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
|
||||
|
||||
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/lib/* ${BASEPATH}/lib
|
||||
cp ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/third_party/libjpeg-turbo/lib/*so* ${BASEPATH}/lib
|
||||
cd ${BASEPATH}/ || exit
|
||||
|
||||
mvn package
|
|
@ -0,0 +1,51 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>com.mindspore.flclient.demo</groupId>
|
||||
<artifactId>quick_start_flclient</artifactId>
|
||||
<version>1.0</version>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.mindspore.flclient</groupId>
|
||||
<artifactId>mindspore-lite-flclient</artifactId>
|
||||
<version>1.0</version>
|
||||
<scope>system</scope>
|
||||
<systemPath>${project.basedir}/lib/mindspore-lite-java-flclient.jar</systemPath>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<finalName>${project.name}</finalName>
|
||||
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<descriptorRefs>
|
||||
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||
</descriptorRefs>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>make-assemble</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>single</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.demo.albert;
|
||||
|
||||
/**
|
||||
* feature class
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class Feature {
|
||||
int[] inputIds;
|
||||
int[] inputMasks;
|
||||
int[] tokenIds;
|
||||
int labelIds;
|
||||
int seqLen;
|
||||
|
||||
/**
|
||||
* constructor
|
||||
*
|
||||
* @param inputIds input id
|
||||
* @param inputMasks input masks
|
||||
* @param tokenIds token ids
|
||||
* @param labelIds label ids
|
||||
* @param seqLen seq len
|
||||
*/
|
||||
public Feature(int[] inputIds, int[] inputMasks, int[] tokenIds, int labelIds, int seqLen) {
|
||||
this.inputIds = inputIds;
|
||||
this.inputMasks = inputMasks;
|
||||
this.tokenIds = tokenIds;
|
||||
this.labelIds = labelIds;
|
||||
this.seqLen = seqLen;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.demo.common;
|
||||
|
||||
import com.mindspore.flclient.model.Callback;
|
||||
import com.mindspore.flclient.model.CommonUtils;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Defining the Callback calculate classifier model.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class ClassifierAccuracyCallback extends Callback {
|
||||
private static final Logger LOGGER = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
|
||||
private final int numOfClass;
|
||||
private final int batchSize;
|
||||
private final List<Integer> targetLabels;
|
||||
private float accuracy;
|
||||
|
||||
/**
|
||||
* Defining a constructor of ClassifierAccuracyCallback.
|
||||
*/
|
||||
public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List<Integer> targetLabels) {
|
||||
super(session);
|
||||
this.batchSize = batchSize;
|
||||
this.numOfClass = numOfClass;
|
||||
this.targetLabels = targetLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get eval accuracy.
|
||||
*
|
||||
* @return accuracy.
|
||||
*/
|
||||
public float getAccuracy() {
|
||||
return accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status stepBegin() {
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status stepEnd() {
|
||||
Status status = calAccuracy();
|
||||
if (status != Status.SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
steps++;
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status epochBegin() {
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status epochEnd() {
|
||||
LOGGER.info("average accuracy:" + steps + ",acc is:" + accuracy / steps);
|
||||
accuracy = accuracy / steps;
|
||||
steps = 0;
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
private Status calAccuracy() {
|
||||
if (targetLabels == null || targetLabels.isEmpty()) {
|
||||
LOGGER.severe("labels cannot be null");
|
||||
return Status.NULLPTR;
|
||||
}
|
||||
Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
|
||||
if (!outputTensor.isPresent()) {
|
||||
return Status.NULLPTR;
|
||||
}
|
||||
float[] scores = outputTensor.get().getFloatData();
|
||||
int hitCounts = 0;
|
||||
for (int b = 0; b < batchSize; b++) {
|
||||
int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
|
||||
if (targetLabels.get(b + steps * batchSize) == predictIdx) {
|
||||
hitCounts += 1;
|
||||
}
|
||||
}
|
||||
accuracy += ((float) (hitCounts) / batchSize);
|
||||
LOGGER.info("steps:" + steps + ",acc is:" + (float) (hitCounts) / batchSize);
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.demo.common;
|
||||
|
||||
import com.mindspore.flclient.model.Callback;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Defining the Callback get model predict result.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class PredictCallback extends Callback {
|
||||
private static final Logger LOGGER = Logger.getLogger(PredictCallback.class.toString());
|
||||
private final List<Integer> predictResults = new ArrayList<>();
|
||||
private final int numOfClass;
|
||||
private final int batchSize;
|
||||
|
||||
/**
|
||||
* Defining a constructor of predict callback.
|
||||
*/
|
||||
public PredictCallback(LiteSession session, int batchSize, int numOfClass) {
|
||||
super(session);
|
||||
this.batchSize = batchSize;
|
||||
this.numOfClass = numOfClass;
|
||||
}
|
||||
|
||||
public static int getMaxScoreIndex(float[] scores, int start, int end) {
|
||||
if (scores != null && scores.length != 0) {
|
||||
if (start < scores.length && start >= 0 && end <= scores.length && end >= 0) {
|
||||
float maxScore = scores[start];
|
||||
int maxIdx = start;
|
||||
|
||||
for (int i = start; i < end; ++i) {
|
||||
if (scores[i] > maxScore) {
|
||||
maxIdx = i;
|
||||
maxScore = scores[i];
|
||||
}
|
||||
}
|
||||
|
||||
return maxIdx - start;
|
||||
} else {
|
||||
LOGGER.severe("start,end cannot out of scores length");
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
LOGGER.severe("scores cannot be empty");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get predict results.
|
||||
*
|
||||
* @return predict result.
|
||||
*/
|
||||
public List<Integer> getPredictResults() {
|
||||
return predictResults;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status stepBegin() {
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status stepEnd() {
|
||||
Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
|
||||
if (!outputTensor.isPresent()) {
|
||||
return Status.FAILED;
|
||||
}
|
||||
float[] scores = outputTensor.get().getFloatData();
|
||||
for (int b = 0; b < batchSize; b++) {
|
||||
int predictIdx = getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
|
||||
predictResults.add(predictIdx);
|
||||
}
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status epochBegin() {
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status epochEnd() {
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.demo.lenet;
|
||||
|
||||
import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback;
|
||||
import com.mindspore.flclient.demo.common.PredictCallback;
|
||||
import com.mindspore.flclient.model.Callback;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.DataSet;
|
||||
import com.mindspore.flclient.model.LossCallback;
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Defining the lenet client base class.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class LenetClient extends Client {
|
||||
private static final Logger LOGGER = Logger.getLogger(LenetClient.class.toString());
|
||||
private static final int NUM_OF_CLASS = 62;
|
||||
|
||||
static {
|
||||
ClientManager.registerClient(new LenetClient());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Callback> initCallbacks(RunType runType, DataSet dataSet) {
|
||||
List<Callback> callbacks = new ArrayList<>();
|
||||
if (runType == RunType.TRAINMODE) {
|
||||
Callback lossCallback = new LossCallback(trainSession);
|
||||
callbacks.add(lossCallback);
|
||||
} else if (runType == RunType.EVALMODE) {
|
||||
if (dataSet instanceof LenetDataSet) {
|
||||
Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS,
|
||||
((LenetDataSet) dataSet).getTargetLabels());
|
||||
callbacks.add(evalCallback);
|
||||
}
|
||||
} else {
|
||||
Callback inferCallback = new PredictCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS);
|
||||
callbacks.add(inferCallback);
|
||||
}
|
||||
return callbacks;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files) {
|
||||
Map<RunType, Integer> sampleCounts = new HashMap<>();
|
||||
List<String> trainFiles = files.getOrDefault(RunType.TRAINMODE, null);
|
||||
if (trainFiles != null) {
|
||||
DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS);
|
||||
trainDataSet.init(trainFiles);
|
||||
dataSets.put(RunType.TRAINMODE, trainDataSet);
|
||||
sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize);
|
||||
}
|
||||
List<String> evalFiles = files.getOrDefault(RunType.EVALMODE, null);
|
||||
if (evalFiles != null) {
|
||||
LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS);
|
||||
evalDataSet.init(evalFiles);
|
||||
dataSets.put(RunType.EVALMODE, evalDataSet);
|
||||
sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize);
|
||||
}
|
||||
List<String> inferFiles = files.getOrDefault(RunType.INFERMODE, null);
|
||||
if (inferFiles != null) {
|
||||
DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS);
|
||||
inferDataSet.init(inferFiles);
|
||||
dataSets.put(RunType.INFERMODE, inferDataSet);
|
||||
sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize);
|
||||
}
|
||||
return sampleCounts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getEvalAccuracy(List<Callback> evalCallbacks) {
|
||||
for (Callback callBack : evalCallbacks) {
|
||||
if (callBack instanceof ClassifierAccuracyCallback) {
|
||||
return ((ClassifierAccuracyCallback) callBack).getAccuracy();
|
||||
}
|
||||
}
|
||||
LOGGER.severe("don not find accuracy related callback");
|
||||
return Float.NaN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Integer> getInferResult(List<Callback> inferCallbacks) {
|
||||
for (Callback callBack : inferCallbacks) {
|
||||
if (callBack instanceof PredictCallback) {
|
||||
return ((PredictCallback) callBack).getPredictResults();
|
||||
}
|
||||
}
|
||||
LOGGER.severe("don not find accuracy related callback");
|
||||
return new ArrayList<>();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.demo.lenet;
|
||||
|
||||
import com.mindspore.flclient.model.DataSet;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Defining the minist dataset for lenet.
|
||||
*
|
||||
* @since v1.0
|
||||
*/
|
||||
public class LenetDataSet extends DataSet {
|
||||
private static final Logger LOGGER = Logger.getLogger(LenetDataSet.class.toString());
|
||||
private static final int IMAGE_SIZE = 32 * 32 * 3;
|
||||
private static final int FLOAT_BYTE_SIZE = 4;
|
||||
private byte[] imageArray;
|
||||
private int[] labelArray;
|
||||
private final int numOfClass;
|
||||
private List<Integer> targetLabels;
|
||||
|
||||
/**
|
||||
* Defining a constructor of lenet dataset.
|
||||
*/
|
||||
public LenetDataSet(int numOfClass) {
|
||||
this.numOfClass = numOfClass;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get dataset labels.
|
||||
*
|
||||
* @return dataset target labels.
|
||||
*/
|
||||
public List<Integer> getTargetLabels() {
|
||||
return targetLabels;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
|
||||
// infer,train,eval model is same one
|
||||
if (inputsBuffer.size() != 2) {
|
||||
LOGGER.severe("input size error");
|
||||
return;
|
||||
}
|
||||
if (batchIdx > batchNum) {
|
||||
LOGGER.severe("fill model image input failed");
|
||||
return;
|
||||
}
|
||||
for (ByteBuffer inputBuffer : inputsBuffer) {
|
||||
inputBuffer.clear();
|
||||
}
|
||||
ByteBuffer imageBuffer = inputsBuffer.get(0);
|
||||
ByteBuffer labelIdBuffer = inputsBuffer.get(1);
|
||||
int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES;
|
||||
for (int i = 0; i < imageInputBytes; i++) {
|
||||
imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]);
|
||||
}
|
||||
if (labelArray == null) {
|
||||
return;
|
||||
}
|
||||
int labelSize = batchSize * numOfClass;
|
||||
if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) {
|
||||
LOGGER.severe("fill model label input failed");
|
||||
return;
|
||||
}
|
||||
labelIdBuffer.clear();
|
||||
for (int i = 0; i < labelSize; i++) {
|
||||
labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shuffle() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void padding() {
|
||||
if (labelArray == null) // infer model
|
||||
{
|
||||
labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)];
|
||||
Arrays.fill(labelArray, 0);
|
||||
}
|
||||
int curSize = labelArray.length / numOfClass;
|
||||
int modSize = curSize - curSize / batchSize * batchSize;
|
||||
int padSize = modSize != 0 ? batchSize * numOfClass - modSize : 0;
|
||||
if (padSize != 0) {
|
||||
int[] padLabelArray = new int[labelArray.length + padSize * numOfClass];
|
||||
byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES];
|
||||
System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length);
|
||||
System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length);
|
||||
for (int i = 0; i < padSize; i++) {
|
||||
int idx = (int) (Math.random() * curSize);
|
||||
System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass,
|
||||
numOfClass);
|
||||
System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray,
|
||||
padImageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES);
|
||||
}
|
||||
labelArray = padLabelArray;
|
||||
imageArray = padImageArray;
|
||||
}
|
||||
sampleSize = curSize + padSize;
|
||||
batchNum = sampleSize / batchSize;
|
||||
setPredictLabels(labelArray);
|
||||
LOGGER.info("total samples:" + sampleSize);
|
||||
LOGGER.info("total batchNum:" + batchNum);
|
||||
}
|
||||
|
||||
private void setPredictLabels(int[] labelArray) {
|
||||
int labels_num = labelArray.length / numOfClass;
|
||||
targetLabels = new ArrayList<>(labels_num);
|
||||
for (int i = 0; i < labels_num; i++) {
|
||||
int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1));
|
||||
if (label == -1) {
|
||||
LOGGER.severe("get max index failed");
|
||||
}
|
||||
targetLabels.add(label);
|
||||
}
|
||||
}
|
||||
|
||||
private int getMaxIndex(int[] nums, int begin, int end) {
|
||||
for (int i = begin; i < end; i++) {
|
||||
if (nums[i] == 1) {
|
||||
return i - begin;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
public static byte[] readBinFile(String dataFile) {
|
||||
if (dataFile == null || dataFile.isEmpty()) {
|
||||
LOGGER.severe("file cannot be empty");
|
||||
return new byte[0];
|
||||
}
|
||||
// read train file
|
||||
Path path = Paths.get(dataFile);
|
||||
byte[] data = new byte[0];
|
||||
try {
|
||||
data = Files.readAllBytes(path);
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe("read data file failed,please check data file path");
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status dataPreprocess(List<String> files) {
|
||||
String labelFile = "";
|
||||
String imageFile;
|
||||
if (files.size() == 2) {
|
||||
imageFile = files.get(0);
|
||||
labelFile = files.get(1);
|
||||
} else if (files.size() == 1) {
|
||||
imageFile = files.get(0);
|
||||
} else {
|
||||
LOGGER.severe("files size error");
|
||||
return Status.FAILED;
|
||||
}
|
||||
imageArray = readBinFile(imageFile);
|
||||
if (labelFile != null && !labelFile.isEmpty()) {
|
||||
byte[] labelByteArray = readBinFile(labelFile);
|
||||
targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE);
|
||||
// model labels use one hot
|
||||
labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass];
|
||||
Arrays.fill(labelArray, 0);
|
||||
int offset = 0;
|
||||
for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) {
|
||||
labelArray[offset * numOfClass + labelByteArray[i]] = 1;
|
||||
offset++;
|
||||
}
|
||||
} else {
|
||||
labelArray = null; // labelArray may be initialized from train
|
||||
}
|
||||
sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE;
|
||||
return Status.SUCCESS;
|
||||
}
|
||||
}
|
|
@ -152,6 +152,10 @@ public class SessionUtil {
|
|||
logger.severe(Common.addTag("trainSession cannot be null"));
|
||||
return;
|
||||
}
|
||||
if(trainSession.getSessionPtr() == 0) {
|
||||
logger.warning(Common.addTag("trainSession pointer has already free"));
|
||||
return;
|
||||
}
|
||||
trainSession.free();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue