forked from mindspore-Ecosystem/mindspore
add flclient lenet example
This commit is contained in:
parent
901d81d16d
commit
b9e716ab7a
|
@ -0,0 +1,73 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
display_usage()
|
||||||
|
{
|
||||||
|
echo -e "\nUsage: build.sh [-r release.tar.gz]\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts()
|
||||||
|
{
|
||||||
|
TARBALL=""
|
||||||
|
while getopts 'r:' opt
|
||||||
|
do
|
||||||
|
case "${opt}" in
|
||||||
|
r)
|
||||||
|
TARBALL=$OPTARG
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option ${opt}!"
|
||||||
|
display_usage
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
checkopts "$@"
|
||||||
|
|
||||||
|
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
|
||||||
|
get_version() {
|
||||||
|
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||||
|
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||||
|
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
|
||||||
|
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
|
||||||
|
}
|
||||||
|
get_version
|
||||||
|
|
||||||
|
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
|
||||||
|
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
|
||||||
|
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
|
||||||
|
|
||||||
|
rm -rf build/${MINDSPORE_FILE_NAME}
|
||||||
|
rm -rf lib
|
||||||
|
mkdir -p build/${MINDSPORE_FILE_NAME}
|
||||||
|
mkdir -p lib
|
||||||
|
|
||||||
|
if [ -n "$TARBALL" ]; then
|
||||||
|
cp ${TARBALL} ${BASEPATH}/build/${MINDSPORE_FILE}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
|
||||||
|
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
|
||||||
|
fi
|
||||||
|
|
||||||
|
tar xzvf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/${MINDSPORE_FILE_NAME} --strip-components=1
|
||||||
|
|
||||||
|
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/lib/* ${BASEPATH}/lib
|
||||||
|
cp ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime/third_party/libjpeg-turbo/lib/*so* ${BASEPATH}/lib
|
||||||
|
cd ${BASEPATH}/ || exit
|
||||||
|
|
||||||
|
mvn package
|
|
@ -0,0 +1,51 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<groupId>com.mindspore.flclient.demo</groupId>
|
||||||
|
<artifactId>quick_start_flclient</artifactId>
|
||||||
|
<version>1.0</version>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<maven.compiler.source>8</maven.compiler.source>
|
||||||
|
<maven.compiler.target>8</maven.compiler.target>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.mindspore.flclient</groupId>
|
||||||
|
<artifactId>mindspore-lite-flclient</artifactId>
|
||||||
|
<version>1.0</version>
|
||||||
|
<scope>system</scope>
|
||||||
|
<systemPath>${project.basedir}/lib/mindspore-lite-java-flclient.jar</systemPath>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<finalName>${project.name}</finalName>
|
||||||
|
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-assembly-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<descriptorRefs>
|
||||||
|
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||||
|
</descriptorRefs>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>make-assemble</id>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>single</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</project>
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.flclient.demo.albert;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* feature class
|
||||||
|
*
|
||||||
|
* @since v1.0
|
||||||
|
*/
|
||||||
|
public class Feature {
|
||||||
|
int[] inputIds;
|
||||||
|
int[] inputMasks;
|
||||||
|
int[] tokenIds;
|
||||||
|
int labelIds;
|
||||||
|
int seqLen;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* constructor
|
||||||
|
*
|
||||||
|
* @param inputIds input id
|
||||||
|
* @param inputMasks input masks
|
||||||
|
* @param tokenIds token ids
|
||||||
|
* @param labelIds label ids
|
||||||
|
* @param seqLen seq len
|
||||||
|
*/
|
||||||
|
public Feature(int[] inputIds, int[] inputMasks, int[] tokenIds, int labelIds, int seqLen) {
|
||||||
|
this.inputIds = inputIds;
|
||||||
|
this.inputMasks = inputMasks;
|
||||||
|
this.tokenIds = tokenIds;
|
||||||
|
this.labelIds = labelIds;
|
||||||
|
this.seqLen = seqLen;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.flclient.demo.common;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.model.Callback;
|
||||||
|
import com.mindspore.flclient.model.CommonUtils;
|
||||||
|
import com.mindspore.flclient.model.Status;
|
||||||
|
import com.mindspore.lite.LiteSession;
|
||||||
|
import com.mindspore.lite.MSTensor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining the Callback calculate classifier model.
|
||||||
|
*
|
||||||
|
* @since v1.0
|
||||||
|
*/
|
||||||
|
public class ClassifierAccuracyCallback extends Callback {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
|
||||||
|
private final int numOfClass;
|
||||||
|
private final int batchSize;
|
||||||
|
private final List<Integer> targetLabels;
|
||||||
|
private float accuracy;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining a constructor of ClassifierAccuracyCallback.
|
||||||
|
*/
|
||||||
|
public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List<Integer> targetLabels) {
|
||||||
|
super(session);
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
this.numOfClass = numOfClass;
|
||||||
|
this.targetLabels = targetLabels;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get eval accuracy.
|
||||||
|
*
|
||||||
|
* @return accuracy.
|
||||||
|
*/
|
||||||
|
public float getAccuracy() {
|
||||||
|
return accuracy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status stepBegin() {
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status stepEnd() {
|
||||||
|
Status status = calAccuracy();
|
||||||
|
if (status != Status.SUCCESS) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
steps++;
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status epochBegin() {
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status epochEnd() {
|
||||||
|
LOGGER.info("average accuracy:" + steps + ",acc is:" + accuracy / steps);
|
||||||
|
accuracy = accuracy / steps;
|
||||||
|
steps = 0;
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Status calAccuracy() {
|
||||||
|
if (targetLabels == null || targetLabels.isEmpty()) {
|
||||||
|
LOGGER.severe("labels cannot be null");
|
||||||
|
return Status.NULLPTR;
|
||||||
|
}
|
||||||
|
Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
|
||||||
|
if (!outputTensor.isPresent()) {
|
||||||
|
return Status.NULLPTR;
|
||||||
|
}
|
||||||
|
float[] scores = outputTensor.get().getFloatData();
|
||||||
|
int hitCounts = 0;
|
||||||
|
for (int b = 0; b < batchSize; b++) {
|
||||||
|
int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
|
||||||
|
if (targetLabels.get(b + steps * batchSize) == predictIdx) {
|
||||||
|
hitCounts += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
accuracy += ((float) (hitCounts) / batchSize);
|
||||||
|
LOGGER.info("steps:" + steps + ",acc is:" + (float) (hitCounts) / batchSize);
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.flclient.demo.common;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.model.Callback;
|
||||||
|
import com.mindspore.flclient.model.Status;
|
||||||
|
import com.mindspore.lite.LiteSession;
|
||||||
|
import com.mindspore.lite.MSTensor;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining the Callback get model predict result.
|
||||||
|
*
|
||||||
|
* @since v1.0
|
||||||
|
*/
|
||||||
|
public class PredictCallback extends Callback {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(PredictCallback.class.toString());
|
||||||
|
private final List<Integer> predictResults = new ArrayList<>();
|
||||||
|
private final int numOfClass;
|
||||||
|
private final int batchSize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining a constructor of predict callback.
|
||||||
|
*/
|
||||||
|
public PredictCallback(LiteSession session, int batchSize, int numOfClass) {
|
||||||
|
super(session);
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
this.numOfClass = numOfClass;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int getMaxScoreIndex(float[] scores, int start, int end) {
|
||||||
|
if (scores != null && scores.length != 0) {
|
||||||
|
if (start < scores.length && start >= 0 && end <= scores.length && end >= 0) {
|
||||||
|
float maxScore = scores[start];
|
||||||
|
int maxIdx = start;
|
||||||
|
|
||||||
|
for (int i = start; i < end; ++i) {
|
||||||
|
if (scores[i] > maxScore) {
|
||||||
|
maxIdx = i;
|
||||||
|
maxScore = scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxIdx - start;
|
||||||
|
} else {
|
||||||
|
LOGGER.severe("start,end cannot out of scores length");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOGGER.severe("scores cannot be empty");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get predict results.
|
||||||
|
*
|
||||||
|
* @return predict result.
|
||||||
|
*/
|
||||||
|
public List<Integer> getPredictResults() {
|
||||||
|
return predictResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status stepBegin() {
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status stepEnd() {
|
||||||
|
Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
|
||||||
|
if (!outputTensor.isPresent()) {
|
||||||
|
return Status.FAILED;
|
||||||
|
}
|
||||||
|
float[] scores = outputTensor.get().getFloatData();
|
||||||
|
for (int b = 0; b < batchSize; b++) {
|
||||||
|
int predictIdx = getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
|
||||||
|
predictResults.add(predictIdx);
|
||||||
|
}
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status epochBegin() {
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status epochEnd() {
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,114 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.flclient.demo.lenet;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback;
|
||||||
|
import com.mindspore.flclient.demo.common.PredictCallback;
|
||||||
|
import com.mindspore.flclient.model.Callback;
|
||||||
|
import com.mindspore.flclient.model.Client;
|
||||||
|
import com.mindspore.flclient.model.ClientManager;
|
||||||
|
import com.mindspore.flclient.model.DataSet;
|
||||||
|
import com.mindspore.flclient.model.LossCallback;
|
||||||
|
import com.mindspore.flclient.model.RunType;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining the lenet client base class.
|
||||||
|
*
|
||||||
|
* @since v1.0
|
||||||
|
*/
|
||||||
|
public class LenetClient extends Client {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(LenetClient.class.toString());
|
||||||
|
private static final int NUM_OF_CLASS = 62;
|
||||||
|
|
||||||
|
static {
|
||||||
|
ClientManager.registerClient(new LenetClient());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Callback> initCallbacks(RunType runType, DataSet dataSet) {
|
||||||
|
List<Callback> callbacks = new ArrayList<>();
|
||||||
|
if (runType == RunType.TRAINMODE) {
|
||||||
|
Callback lossCallback = new LossCallback(trainSession);
|
||||||
|
callbacks.add(lossCallback);
|
||||||
|
} else if (runType == RunType.EVALMODE) {
|
||||||
|
if (dataSet instanceof LenetDataSet) {
|
||||||
|
Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS,
|
||||||
|
((LenetDataSet) dataSet).getTargetLabels());
|
||||||
|
callbacks.add(evalCallback);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Callback inferCallback = new PredictCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS);
|
||||||
|
callbacks.add(inferCallback);
|
||||||
|
}
|
||||||
|
return callbacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files) {
|
||||||
|
Map<RunType, Integer> sampleCounts = new HashMap<>();
|
||||||
|
List<String> trainFiles = files.getOrDefault(RunType.TRAINMODE, null);
|
||||||
|
if (trainFiles != null) {
|
||||||
|
DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS);
|
||||||
|
trainDataSet.init(trainFiles);
|
||||||
|
dataSets.put(RunType.TRAINMODE, trainDataSet);
|
||||||
|
sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize);
|
||||||
|
}
|
||||||
|
List<String> evalFiles = files.getOrDefault(RunType.EVALMODE, null);
|
||||||
|
if (evalFiles != null) {
|
||||||
|
LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS);
|
||||||
|
evalDataSet.init(evalFiles);
|
||||||
|
dataSets.put(RunType.EVALMODE, evalDataSet);
|
||||||
|
sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize);
|
||||||
|
}
|
||||||
|
List<String> inferFiles = files.getOrDefault(RunType.INFERMODE, null);
|
||||||
|
if (inferFiles != null) {
|
||||||
|
DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS);
|
||||||
|
inferDataSet.init(inferFiles);
|
||||||
|
dataSets.put(RunType.INFERMODE, inferDataSet);
|
||||||
|
sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize);
|
||||||
|
}
|
||||||
|
return sampleCounts;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float getEvalAccuracy(List<Callback> evalCallbacks) {
|
||||||
|
for (Callback callBack : evalCallbacks) {
|
||||||
|
if (callBack instanceof ClassifierAccuracyCallback) {
|
||||||
|
return ((ClassifierAccuracyCallback) callBack).getAccuracy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOGGER.severe("don not find accuracy related callback");
|
||||||
|
return Float.NaN;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Integer> getInferResult(List<Callback> inferCallbacks) {
|
||||||
|
for (Callback callBack : inferCallbacks) {
|
||||||
|
if (callBack instanceof PredictCallback) {
|
||||||
|
return ((PredictCallback) callBack).getPredictResults();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOGGER.severe("don not find accuracy related callback");
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,201 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.flclient.demo.lenet;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.model.DataSet;
|
||||||
|
import com.mindspore.flclient.model.Status;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining the minist dataset for lenet.
|
||||||
|
*
|
||||||
|
* @since v1.0
|
||||||
|
*/
|
||||||
|
public class LenetDataSet extends DataSet {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(LenetDataSet.class.toString());
|
||||||
|
private static final int IMAGE_SIZE = 32 * 32 * 3;
|
||||||
|
private static final int FLOAT_BYTE_SIZE = 4;
|
||||||
|
private byte[] imageArray;
|
||||||
|
private int[] labelArray;
|
||||||
|
private final int numOfClass;
|
||||||
|
private List<Integer> targetLabels;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining a constructor of lenet dataset.
|
||||||
|
*/
|
||||||
|
public LenetDataSet(int numOfClass) {
|
||||||
|
this.numOfClass = numOfClass;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get dataset labels.
|
||||||
|
*
|
||||||
|
* @return dataset target labels.
|
||||||
|
*/
|
||||||
|
public List<Integer> getTargetLabels() {
|
||||||
|
return targetLabels;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
|
||||||
|
// infer,train,eval model is same one
|
||||||
|
if (inputsBuffer.size() != 2) {
|
||||||
|
LOGGER.severe("input size error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (batchIdx > batchNum) {
|
||||||
|
LOGGER.severe("fill model image input failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (ByteBuffer inputBuffer : inputsBuffer) {
|
||||||
|
inputBuffer.clear();
|
||||||
|
}
|
||||||
|
ByteBuffer imageBuffer = inputsBuffer.get(0);
|
||||||
|
ByteBuffer labelIdBuffer = inputsBuffer.get(1);
|
||||||
|
int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES;
|
||||||
|
for (int i = 0; i < imageInputBytes; i++) {
|
||||||
|
imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]);
|
||||||
|
}
|
||||||
|
if (labelArray == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int labelSize = batchSize * numOfClass;
|
||||||
|
if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) {
|
||||||
|
LOGGER.severe("fill model label input failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
labelIdBuffer.clear();
|
||||||
|
for (int i = 0; i < labelSize; i++) {
|
||||||
|
labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void shuffle() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void padding() {
|
||||||
|
if (labelArray == null) // infer model
|
||||||
|
{
|
||||||
|
labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)];
|
||||||
|
Arrays.fill(labelArray, 0);
|
||||||
|
}
|
||||||
|
int curSize = labelArray.length / numOfClass;
|
||||||
|
int modSize = curSize - curSize / batchSize * batchSize;
|
||||||
|
int padSize = modSize != 0 ? batchSize * numOfClass - modSize : 0;
|
||||||
|
if (padSize != 0) {
|
||||||
|
int[] padLabelArray = new int[labelArray.length + padSize * numOfClass];
|
||||||
|
byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES];
|
||||||
|
System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length);
|
||||||
|
System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length);
|
||||||
|
for (int i = 0; i < padSize; i++) {
|
||||||
|
int idx = (int) (Math.random() * curSize);
|
||||||
|
System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass,
|
||||||
|
numOfClass);
|
||||||
|
System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray,
|
||||||
|
padImageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES);
|
||||||
|
}
|
||||||
|
labelArray = padLabelArray;
|
||||||
|
imageArray = padImageArray;
|
||||||
|
}
|
||||||
|
sampleSize = curSize + padSize;
|
||||||
|
batchNum = sampleSize / batchSize;
|
||||||
|
setPredictLabels(labelArray);
|
||||||
|
LOGGER.info("total samples:" + sampleSize);
|
||||||
|
LOGGER.info("total batchNum:" + batchNum);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setPredictLabels(int[] labelArray) {
|
||||||
|
int labels_num = labelArray.length / numOfClass;
|
||||||
|
targetLabels = new ArrayList<>(labels_num);
|
||||||
|
for (int i = 0; i < labels_num; i++) {
|
||||||
|
int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1));
|
||||||
|
if (label == -1) {
|
||||||
|
LOGGER.severe("get max index failed");
|
||||||
|
}
|
||||||
|
targetLabels.add(label);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getMaxIndex(int[] nums, int begin, int end) {
|
||||||
|
for (int i = begin; i < end; i++) {
|
||||||
|
if (nums[i] == 1) {
|
||||||
|
return i - begin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static byte[] readBinFile(String dataFile) {
|
||||||
|
if (dataFile == null || dataFile.isEmpty()) {
|
||||||
|
LOGGER.severe("file cannot be empty");
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
// read train file
|
||||||
|
Path path = Paths.get(dataFile);
|
||||||
|
byte[] data = new byte[0];
|
||||||
|
try {
|
||||||
|
data = Files.readAllBytes(path);
|
||||||
|
} catch (IOException e) {
|
||||||
|
LOGGER.severe("read data file failed,please check data file path");
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status dataPreprocess(List<String> files) {
|
||||||
|
String labelFile = "";
|
||||||
|
String imageFile;
|
||||||
|
if (files.size() == 2) {
|
||||||
|
imageFile = files.get(0);
|
||||||
|
labelFile = files.get(1);
|
||||||
|
} else if (files.size() == 1) {
|
||||||
|
imageFile = files.get(0);
|
||||||
|
} else {
|
||||||
|
LOGGER.severe("files size error");
|
||||||
|
return Status.FAILED;
|
||||||
|
}
|
||||||
|
imageArray = readBinFile(imageFile);
|
||||||
|
if (labelFile != null && !labelFile.isEmpty()) {
|
||||||
|
byte[] labelByteArray = readBinFile(labelFile);
|
||||||
|
targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE);
|
||||||
|
// model labels use one hot
|
||||||
|
labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass];
|
||||||
|
Arrays.fill(labelArray, 0);
|
||||||
|
int offset = 0;
|
||||||
|
for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) {
|
||||||
|
labelArray[offset * numOfClass + labelByteArray[i]] = 1;
|
||||||
|
offset++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
labelArray = null; // labelArray may be initialized from train
|
||||||
|
}
|
||||||
|
sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE;
|
||||||
|
return Status.SUCCESS;
|
||||||
|
}
|
||||||
|
}
|
|
@ -152,6 +152,10 @@ public class SessionUtil {
|
||||||
logger.severe(Common.addTag("trainSession cannot be null"));
|
logger.severe(Common.addTag("trainSession cannot be null"));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if(trainSession.getSessionPtr() == 0) {
|
||||||
|
logger.warning(Common.addTag("trainSession pointer has already free"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
trainSession.free();
|
trainSession.free();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue