!27780 [MS][LITE] add flclient albert example

Merge pull request !27780 from zhengjun10/fl2
This commit is contained in:
i-robot 2021-12-16 12:22:09 +00:00 committed by Gitee
commit 9558ba49d8
8 changed files with 562 additions and 409 deletions

View File

@ -14,17 +14,16 @@
* limitations under the License.
*/
package com.mindspore.flclient.example.lenet;
package com.mindspore.flclient.demo.albert;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback;
import com.mindspore.flclient.demo.common.PredictCallback;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.LossCallback;
import com.mindspore.flclient.model.RunType;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.config.MSConfig;
import java.util.ArrayList;
@ -34,16 +33,17 @@ import java.util.Map;
import java.util.logging.Logger;
/**
* Defining the lenet client base class.
* Defining the albert client class.
*
* @since v1.0
*/
public class LenetClient extends Client {
private static final Logger logger = Logger.getLogger(LenetClient.class.toString());
private static final int NUM_OF_CLASS = 62;
public class AlbertClient extends Client {
private static final Logger LOGGER = Logger.getLogger(AlbertClient.class.toString());
private static final int NUM_OF_CLASS = 4;
private static final int MAX_SEQ_LEN = 8;
static {
ClientManager.registerClient(new LenetClient());
ClientManager.registerClient(new AlbertClient());
}
@Override
@ -53,9 +53,9 @@ public class LenetClient extends Client {
Callback lossCallback = new LossCallback(trainSession);
callbacks.add(lossCallback);
} else if (runType == RunType.EVALMODE) {
if (dataSet instanceof LenetDataSet) {
if (dataSet instanceof AlbertDataSet) {
Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS,
((LenetDataSet) dataSet).getTargetLabels());
((AlbertDataSet) dataSet).getTargetLabels());
callbacks.add(evalCallback);
}
} else {
@ -70,21 +70,21 @@ public class LenetClient extends Client {
Map<RunType, Integer> sampleCounts = new HashMap<>();
List<String> trainFiles = files.getOrDefault(RunType.TRAINMODE, null);
if (trainFiles != null) {
DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS);
DataSet trainDataSet = new AlbertDataSet(RunType.TRAINMODE, MAX_SEQ_LEN);
trainDataSet.init(trainFiles);
dataSets.put(RunType.TRAINMODE, trainDataSet);
sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize);
}
List<String> evalFiles = files.getOrDefault(RunType.EVALMODE, null);
if (evalFiles != null) {
LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS);
DataSet evalDataSet = new AlbertDataSet(RunType.EVALMODE, MAX_SEQ_LEN);
evalDataSet.init(evalFiles);
dataSets.put(RunType.EVALMODE, evalDataSet);
sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize);
}
List<String> inferFiles = files.getOrDefault(RunType.INFERMODE, null);
if (inferFiles != null) {
DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS);
DataSet inferDataSet = new AlbertDataSet(RunType.INFERMODE, MAX_SEQ_LEN);
inferDataSet.init(inferFiles);
dataSets.put(RunType.INFERMODE, inferDataSet);
sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize);
@ -99,7 +99,7 @@ public class LenetClient extends Client {
return ((ClassifierAccuracyCallback) callBack).getAccuracy();
}
}
logger.severe(Common.addTag("don not find accuracy related callback"));
LOGGER.severe("don not find accuracy related callback");
return Float.NaN;
}
@ -110,7 +110,7 @@ public class LenetClient extends Client {
return ((PredictCallback) callBack).getPredictResults();
}
}
logger.severe(Common.addTag("don not find accuracy related callback"));
LOGGER.severe("don not find accuracy related callback");
return new ArrayList<>();
}
}
}

View File

@ -0,0 +1,225 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.albert;
import com.mindspore.flclient.model.CustomTokenizer;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.Feature;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
/**
 * Defining dataset for albert.
 *
 * <p>Consumes three input files (data, vocab, word-ids; see {@link #dataPreprocess(List)}).
 * Train/eval lines are expected in the form {@code ...<<<label...>>>sentence};
 * infer lines are raw sentences.
 *
 * @since v1.0
 */
public class AlbertDataSet extends DataSet {
    private static final Logger LOGGER = Logger.getLogger(AlbertDataSet.class.toString());
    private static final int INPUT_FILE_NUM = 3;
    private static final int IDS_FILE_INDEX = 2;
    private static final int WORD_SPLIT_NUM = 2;
    private static final int ALBERT_INPUT_SIZE = 4;
    private static final int MASK_INPUT_INDEX = 2;
    private static final int LABEL_INPUT_INDEX = 3;
    private final RunType runType;
    private final List<Feature> features = new ArrayList<>();
    private final int maxSeqLen;
    private final List<Integer> targetLabels = new ArrayList<>();

    /**
     * Defining a constructor of albert dataset.
     *
     * @param runType   which mode this dataset feeds (train/eval/infer).
     * @param maxSeqLen sequence length every feature is padded/truncated to.
     */
    public AlbertDataSet(RunType runType, int maxSeqLen) {
        this.runType = runType;
        this.maxSeqLen = maxSeqLen;
    }

    /**
     * Get dataset labels.
     *
     * @return dataset target labels.
     */
    public List<Integer> getTargetLabels() {
        return targetLabels;
    }

    @Override
    public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
        // infer,train,eval model is same one
        if (inputsBuffer.size() != ALBERT_INPUT_SIZE) {
            LOGGER.severe("input size error");
            return;
        }
        // Valid batch indices are [0, batchNum); the original "> batchNum" guard let
        // batchIdx == batchNum through and overran the features list below.
        if (batchIdx >= batchNum) {
            LOGGER.severe("fill model image input failed");
            return;
        }
        for (ByteBuffer inputBuffer : inputsBuffer) {
            inputBuffer.clear();
        }
        ByteBuffer tokenIdBuffer = inputsBuffer.get(0);
        ByteBuffer inputIdBuffer = inputsBuffer.get(1);
        ByteBuffer maskIdBuffer = inputsBuffer.get(MASK_INPUT_INDEX);
        ByteBuffer labelIdBuffer = inputsBuffer.get(LABEL_INPUT_INDEX);
        for (int i = 0; i < batchSize; i++) {
            Feature feature = features.get(batchIdx * batchSize + i);
            for (int j = 0; j < maxSeqLen; j++) {
                inputIdBuffer.putInt(feature.inputIds[j]);
                tokenIdBuffer.putInt(feature.tokenIds[j]);
                maskIdBuffer.putInt(feature.inputMasks[j]);
            }
            // Infer mode has no labels; train/eval record them for accuracy checks.
            if (runType != RunType.INFERMODE) {
                labelIdBuffer.putInt(feature.labelIds);
                targetLabels.add(feature.labelIds);
            }
        }
    }

    @Override
    public void shuffle() {
    }

    @Override
    public void padding() {
        if (batchSize <= 0) {
            LOGGER.severe("batch size should bigger than 0");
            return;
        }
        LOGGER.info("before pad samples size:" + features.size());
        int curSize = features.size();
        int modSize = curSize - curSize / batchSize * batchSize;
        int padSize = modSize != 0 ? batchSize - modSize : 0;
        // Duplicate random samples until the sample count divides evenly by batchSize.
        for (int i = 0; i < padSize; i++) {
            int idx = (int) (Math.random() * curSize);
            features.add(features.get(idx));
        }
        sampleSize = features.size();
        batchNum = features.size() / batchSize;
        LOGGER.info("after pad samples size:" + features.size());
        LOGGER.info("after pad batch num:" + batchNum);
    }

    /**
     * Read all lines of a UTF-8 text file; returns an empty list on any failure.
     *
     * @param file text file path, may be null.
     * @return file lines, or an empty list if the path is null or unreadable.
     */
    private static List<String> readTxtFile(String file) {
        if (file == null) {
            LOGGER.severe("file cannot be empty");
            return new ArrayList<>();
        }
        Path path = Paths.get(file);
        List<String> allLines = new ArrayList<>();
        try {
            allLines = Files.readAllLines(path, StandardCharsets.UTF_8);
        } catch (IOException e) {
            LOGGER.severe("read txt file failed,please check txt file path");
        }
        return allLines;
    }

    /**
     * Parse labeled train/eval lines into features (train mode also applies random masking).
     *
     * @param trainFile labeled data file.
     * @param vocabFile vocab file for the tokenizer.
     * @param idsFile   word-id file for the tokenizer.
     * @return SUCCESS, or NULLPTR if any path is null.
     */
    private Status convertTrainData(String trainFile, String vocabFile, String idsFile) {
        if (trainFile == null || vocabFile == null || idsFile == null) {
            LOGGER.severe("dataset init failed,trainFile,idsFile,vocabFile cannot be empty");
            return Status.NULLPTR;
        }
        // read train file
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, maxSeqLen);
        List<String> allLines = readTxtFile(trainFile);
        List<String> examples = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        for (String line : allLines) {
            String[] tokens = line.split(">>>");
            if (tokens.length != WORD_SPLIT_NUM) {
                LOGGER.warning("line may have format problem,need include >>>");
                continue;
            }
            examples.add(tokens[1]);
            tokens = tokens[0].split("<<<");
            if (tokens.length != WORD_SPLIT_NUM) {
                // Bug fix: this branch checks the "<<<" delimiter, not ">>>".
                LOGGER.warning("line may have format problem,need include <<<");
                continue;
            }
            labels.add(tokens[1]);
        }
        for (int i = 0; i < examples.size(); i++) {
            List<Integer> tokens = customTokenizer.tokenize(examples.get(i), runType == RunType.TRAINMODE);
            if (tokens.isEmpty()) {
                continue;
            }
            Optional<Feature> feature = customTokenizer.getFeatures(tokens, labels.get(i));
            if (!feature.isPresent()) {
                continue;
            }
            if (runType == RunType.TRAINMODE) {
                customTokenizer.addRandomMaskAndReplace(feature.get(), true, true);
            }
            feature.ifPresent(features::add);
        }
        sampleSize = features.size();
        return Status.SUCCESS;
    }

    /**
     * Parse unlabeled infer lines into features (placeholder label "other").
     *
     * @param inferFile raw sentence file.
     * @param vocabFile vocab file for the tokenizer.
     * @param idsFile   word-id file for the tokenizer.
     * @return SUCCESS, or NULLPTR if any path is null.
     */
    private Status convertInferData(String inferFile, String vocabFile, String idsFile) {
        if (inferFile == null || vocabFile == null || idsFile == null) {
            LOGGER.severe("dataset init failed,trainFile,idsFile,vocabFile cannot be empty");
            return Status.NULLPTR;
        }
        // read train file
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, maxSeqLen);
        List<String> allLines = readTxtFile(inferFile);
        for (String line : allLines) {
            if (line.isEmpty()) {
                continue;
            }
            List<Integer> tokens = customTokenizer.tokenize(line, runType == RunType.TRAINMODE);
            Optional<Feature> feature = customTokenizer.getFeatures(tokens, "other");
            if (!feature.isPresent()) {
                continue;
            }
            features.add(feature.get());
        }
        sampleSize = features.size();
        return Status.SUCCESS;
    }

    @Override
    public Status dataPreprocess(List<String> files) {
        if (files.size() != INPUT_FILE_NUM) {
            LOGGER.severe("files size error");
            return Status.FAILED;
        }
        String dataFile = files.get(0);
        String vocabFile = files.get(1);
        String idsFile = files.get(IDS_FILE_INDEX);
        if (runType == RunType.TRAINMODE || runType == RunType.EVALMODE) {
            return convertTrainData(dataFile, vocabFile, idsFile);
        } else {
            return convertInferData(dataFile, vocabFile, idsFile);
        }
    }
}

View File

@ -0,0 +1,313 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.demo.albert;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
/**
 * custom tokenizer class: splits raw text into word pieces, maps the pieces to
 * vocab ids, and builds albert features with optional BERT-style random masking.
 *
 * @since v1.0
 */
public class CustomTokenizer {
    private static final Logger LOGGER = Logger.getLogger(CustomTokenizer.class.toString());
    private static final char CHINESE_START_CODE = '\u4e00';
    private static final char CHINESE_END_CODE = '\u9fa5';
    private static final int VOCAB_SIZE = 11682;
    // Two positions are reserved for the [CLS]/[SEP] wrapper tokens.
    private static final int RESERVED_LEN = 2;
    // Vocab id substituted for masked tokens (the [MASK] id).
    private static final int FILL_NUM = 103;
    // 15% of tokens are selected for masking; of those 80% -> [MASK],
    // 10% are kept unchanged, 10% -> a random vocab id.
    private static final float LOW_THRESHOLD = 0.15f;
    private static final float MID_THRESHOLD = 0.8f;
    private static final float HIGH_THRESHOLD = 0.9f;
    private final Map<String, Integer> vocabs = new HashMap<>();
    private final int maxInputChars = 100;
    private final String[] notSplitStrs = {"UNK"};
    private int maxSeqLen = 8;
    private Map<String, Integer> labelMap = new HashMap<String, Integer>() {
        {
            put("good", 0);
            put("leimu", 1);
            put("xiaoku", 2);
            put("xin", 3);
            put("other", 4);
        }
    };

    /**
     * Greedy longest-match word-piece split of a single token; continuation
     * pieces get the "##" prefix. Falls back to a single "[UNK]" if any
     * position cannot be matched.
     */
    private List<String> getPieceToken(String token) {
        List<String> subTokens = new ArrayList<>();
        boolean isBad = false;
        int start = 0;
        int tokenLen = token.length();
        while (start < tokenLen) {
            int end = tokenLen;
            String curStr = "";
            while (start < end) {
                String subStr = token.substring(start, end);
                if (start > 0) {
                    subStr = "##" + subStr;
                }
                if (vocabs.containsKey(subStr)) {
                    curStr = subStr;
                    break;
                }
                end = end - 1;
            }
            if (curStr.isEmpty()) {
                isBad = true;
                break;
            }
            subTokens.add(curStr);
            start = end;
        }
        if (isBad) {
            return new ArrayList<>(Collections.singletonList("[UNK]"));
        } else {
            return subTokens;
        }
    }

    /**
     * init tokenizer
     *
     * @param vocabFile vocab file path
     * @param idsFile word id file path
     * @param seqLen max word len to clamp
     */
    public void init(String vocabFile, String idsFile, int seqLen) {
        if (vocabFile == null || idsFile == null) {
            LOGGER.severe("idsFile,vocabFile cannot be empty");
            return;
        }
        Path vocabPath = Paths.get(vocabFile);
        List<String> vocabLines;
        try {
            vocabLines = Files.readAllLines(vocabPath, StandardCharsets.UTF_8);
        } catch (IOException e) {
            LOGGER.severe("read vocab file failed, please check vocab file path");
            return;
        }
        Path idsPath = Paths.get(idsFile);
        List<String> idsLines;
        try {
            idsLines = Files.readAllLines(idsPath, StandardCharsets.UTF_8);
        } catch (IOException e) {
            LOGGER.severe("read ids file failed, please check ids file path");
            return;
        }
        // Bound by the shorter file: the original indexed vocabLines with the ids
        // line count and could throw IndexOutOfBoundsException on mismatched files.
        int lineNum = Math.min(vocabLines.size(), idsLines.size());
        for (int i = 0; i < lineNum; ++i) {
            try {
                vocabs.put(vocabLines.get(i), Integer.parseInt(idsLines.get(i)));
            } catch (NumberFormatException e) {
                LOGGER.severe("id lines has invalid content");
                return;
            }
        }
        maxSeqLen = seqLen;
    }

    /**
     * check char is chinese char or punc char
     *
     * @param trimChar input char
     * @return ture if char is chinese char or punc char,else false
     */
    public Boolean isChineseOrPunc(char trimChar) {
        // is chinese char
        if (trimChar >= CHINESE_START_CODE && trimChar <= CHINESE_END_CODE) {
            return true;
        }
        // is puncuation char,according to the ascii table.
        boolean isFrontPuncChar = (trimChar >= '!' && trimChar <= '/') || (trimChar >= ':' && trimChar <= '@');
        boolean isBackPuncChar = (trimChar >= '[' && trimChar <= '`') || (trimChar >= '{' && trimChar <= '~');
        return isFrontPuncChar || isBackPuncChar;
    }

    /**
     * split text
     *
     * @param text input text
     * @return split string array
     */
    public String[] splitText(String text) {
        if (text == null) {
            return new String[0];
        }
        // clean remove white and control char
        String trimText = text.trim();
        StringBuilder cleanText = new StringBuilder();
        for (int i = 0; i < trimText.length(); i++) {
            // Chinese/punctuation chars become standalone tokens; everything else
            // stays attached to its neighbours and is split on whitespace below.
            if (isChineseOrPunc(trimText.charAt(i))) {
                cleanText.append(" ").append(trimText.charAt(i)).append(" ");
            } else {
                cleanText.append(trimText.charAt(i));
            }
        }
        return cleanText.toString().trim().split("\\s+");
    }

    /**
     * combine token to piece
     *
     * @param tokens input tokens
     * @return pieces
     */
    public List<String> wordPieceTokenize(String[] tokens) {
        if (tokens == null) {
            return new ArrayList<>();
        }
        List<String> outputTokens = new ArrayList<>();
        for (String token : tokens) {
            List<String> subTokens = getPieceToken(token);
            outputTokens.addAll(subTokens);
        }
        return outputTokens;
    }

    /**
     * convert token to id, truncating to maxSeqLen (with optional cyclic
     * truncation starting at a random offset) and wrapping with [CLS]/[SEP].
     *
     * @param tokens input tokens; never modified by this method.
     * @param isCycTrunc if need cyc trunc
     * @return ids
     */
    public List<Integer> convertTokensToIds(List<String> tokens, boolean isCycTrunc) {
        int seqLen = tokens.size();
        List<String> truncTokens;
        if (tokens.size() > maxSeqLen - RESERVED_LEN) {
            if (isCycTrunc) {
                int randIndex = (int) (Math.random() * seqLen);
                if (randIndex > seqLen - maxSeqLen + RESERVED_LEN) {
                    // Wrap around: take the tail from randIndex plus the needed head.
                    // Copy into a fresh list — the original appended one subList view of
                    // `tokens` to another view of the same list, which both mutated the
                    // caller's list and threw ConcurrentModificationException.
                    truncTokens = new ArrayList<>(tokens.subList(randIndex, seqLen));
                    truncTokens.addAll(tokens.subList(0, randIndex + maxSeqLen - RESERVED_LEN - seqLen));
                } else {
                    truncTokens = new ArrayList<>(tokens.subList(randIndex, randIndex + maxSeqLen - RESERVED_LEN));
                }
            } else {
                truncTokens = new ArrayList<>(tokens.subList(0, maxSeqLen - RESERVED_LEN));
            }
        } else {
            truncTokens = new ArrayList<>(tokens);
        }
        truncTokens.add(0, "[CLS]");
        truncTokens.add("[SEP]");
        List<Integer> ids = new ArrayList<>(truncTokens.size());
        for (String token : truncTokens) {
            ids.add(vocabs.getOrDefault(token, vocabs.get("[UNK]")));
        }
        return ids;
    }

    /**
     * add random mask and replace feature (BERT masked-LM scheme: 15% of
     * positions selected; 80% -> [MASK], 10% kept, 10% -> random id).
     *
     * @param feature text feature, modified in place.
     * @param isKeptFirst if keep first char not change
     * @param isKeptLast if keep last char not change
     */
    public void addRandomMaskAndReplace(Feature feature, boolean isKeptFirst, boolean isKeptLast) {
        if (feature == null || feature.seqLen <= 0) {
            return;
        }
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 1);
        // replaces must default to 0: the output is inputIds[i] * masks[i] + replaces[i],
        // so a kept position (mask 1) with a non-zero default would shift every id.
        int[] replaces = new int[maxSeqLen];
        int[] inputIds = feature.inputIds;
        for (int i = 0; i < feature.seqLen; i++) {
            double rand1 = Math.random();
            if (rand1 < LOW_THRESHOLD) {
                masks[i] = 0;
                double rand2 = Math.random();
                if (rand2 < MID_THRESHOLD) {
                    replaces[i] = FILL_NUM;
                } else if (rand2 < HIGH_THRESHOLD) {
                    masks[i] = 1;
                } else {
                    replaces[i] = (int) (Math.random() * VOCAB_SIZE);
                }
            }
        }
        // Apply keep-first/keep-last once, after the random pass. The original did
        // this inside the loop for every i, which reset all masks and disabled
        // masking entirely whenever isKeptFirst was true.
        if (isKeptFirst) {
            masks[0] = 1;
            replaces[0] = 0;
        }
        if (isKeptLast) {
            masks[feature.seqLen - 1] = 1;
            replaces[feature.seqLen - 1] = 0;
        }
        for (int i = 0; i < feature.seqLen; i++) {
            inputIds[i] = inputIds[i] * masks[i] + replaces[i];
        }
    }

    /**
     * get feature
     *
     * @param tokens input tokens
     * @param label input label
     * @return feature
     */
    public Optional<Feature> getFeatures(List<Integer> tokens, String label) {
        if (tokens == null || label == null) {
            LOGGER.warning("tokens or label is null");
            return Optional.empty();
        }
        if (!labelMap.containsKey(label)) {
            return Optional.empty();
        }
        int[] segmentIds = new int[maxSeqLen];
        Arrays.fill(segmentIds, 0);
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 0);
        Arrays.fill(masks, 0, tokens.size(), 1); // tokens size can ensure less than masks
        int[] inputIds = new int[maxSeqLen];
        Arrays.fill(inputIds, 0);
        for (int i = 0; i < tokens.size(); i++) {
            inputIds[i] = tokens.get(i);
        }
        return Optional.of(new Feature(inputIds, masks, segmentIds, labelMap.get(label), tokens.size()));
    }

    /**
     * tokenize text to tokens
     *
     * @param text input tokens
     * @param isTrainMode if work in train mod
     * @return tokens
     */
    public List<Integer> tokenize(String text, boolean isTrainMode) {
        if (text == null) {
            LOGGER.warning("text is empty,skip it");
            return new ArrayList<>();
        }
        String[] splitTokens = splitText(text);
        List<String> wordPieceTokens = wordPieceTokenize(splitTokens);
        return convertTokensToIds(wordPieceTokens, isTrainMode); // trainMod need cyclicTrunc
    }
}

View File

@ -1,111 +0,0 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.example.lenet;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.Status;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
/**
 * Defining the Callback calculate classifier model.
 *
 * <p>After each eval step the top-1 hit rate of the batch is added to a running
 * sum; {@code epochEnd} converts the sum into the average accuracy.
 *
 * @since v1.0
 */
public class ClassifierAccuracyCallback extends Callback {
    private static final Logger logger = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
    private final int numOfClass;
    private final int batchSize;
    private final List<Integer> targetLabels;
    private float accuracy;

    /**
     * Defining a constructor of ClassifierAccuracyCallback.
     */
    public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List<Integer> targetLabels) {
        super(session);
        this.batchSize = batchSize;
        this.numOfClass = numOfClass;
        this.targetLabels = targetLabels;
    }

    /**
     * Get eval accuracy.
     *
     * @return accuracy.
     */
    public float getAccuracy() {
        return accuracy;
    }

    @Override
    public Status stepBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status stepEnd() {
        // Only count the step once its accuracy was folded in successfully.
        Status ret = calAccuracy();
        if (ret == Status.SUCCESS) {
            steps++;
        }
        return ret;
    }

    @Override
    public Status epochBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status epochEnd() {
        logger.info(Common.addTag("average accuracy:" + steps + ",acc is:" + accuracy / steps));
        accuracy = accuracy / steps;
        steps = 0;
        return Status.SUCCESS;
    }

    // Compares the arg-max of each sample's scores against its target label and
    // adds the batch hit rate to the running accuracy sum.
    private Status calAccuracy() {
        if (targetLabels == null || targetLabels.isEmpty()) {
            logger.severe(Common.addTag("labels cannot be null"));
            return Status.NULLPTR;
        }
        Optional<MSTensor> output = searchOutputsForSize(batchSize * numOfClass);
        if (!output.isPresent()) {
            return Status.NULLPTR;
        }
        float[] scores = output.get().getFloatData();
        int labelOffset = steps * batchSize;
        int hits = 0;
        for (int sample = 0; sample < batchSize; sample++) {
            int begin = sample * numOfClass;
            int predicted = CommonUtils.getMaxScoreIndex(scores, begin, begin + numOfClass);
            if (targetLabels.get(labelOffset + sample) == predicted) {
                hits += 1;
            }
        }
        accuracy += ((float) (hits) / batchSize);
        logger.info(Common.addTag("steps:" + steps + ",acc is:" + (float) (hits) / batchSize));
        return Status.SUCCESS;
    }
}

View File

@ -1,188 +0,0 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.example.lenet;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.FileUtil;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
/**
 * Defining the minist dataset for lenet.
 *
 * <p>Images are stored as a flat float byte array ({@code IMAGE_SIZE} floats per
 * sample); labels are stored one-hot with {@code numOfClass} ints per sample.
 *
 * @since v1.0
 */
public class LenetDataSet extends DataSet {
    // Fixed: the logger was created for DataSet.class, not this class.
    private static final Logger logger = Logger.getLogger(LenetDataSet.class.toString());
    private static final int IMAGE_SIZE = 32 * 32 * 3;
    private static final int FLOAT_BYTE_SIZE = 4;
    private byte[] imageArray;
    private int[] labelArray;
    private final int numOfClass;
    private List<Integer> targetLabels;

    /**
     * Defining a constructor of lenet dataset.
     *
     * @param numOfClass number of classification classes (one-hot label width).
     */
    public LenetDataSet(int numOfClass) {
        this.numOfClass = numOfClass;
    }

    /**
     * Get dataset labels.
     *
     * @return dataset target labels.
     */
    public List<Integer> getTargetLabels() {
        return targetLabels;
    }

    @Override
    public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
        // infer,train,eval model is same one
        if (inputsBuffer.size() != 2) {
            logger.severe(Common.addTag("input size error"));
            return;
        }
        // Valid batch indices are [0, batchNum); the original "> batchNum" guard let
        // batchIdx == batchNum through and overran imageArray below.
        if (batchIdx >= batchNum) {
            logger.severe(Common.addTag("fill model image input failed"));
            return;
        }
        for (ByteBuffer inputBuffer : inputsBuffer) {
            inputBuffer.clear();
        }
        ByteBuffer imageBuffer = inputsBuffer.get(0);
        ByteBuffer labelIdBuffer = inputsBuffer.get(1);
        int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES;
        for (int i = 0; i < imageInputBytes; i++) {
            imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]);
        }
        // Infer mode has no labels.
        if (labelArray == null) {
            return;
        }
        int labelSize = batchSize * numOfClass;
        if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) {
            logger.severe(Common.addTag("fill model label input failed"));
            return;
        }
        labelIdBuffer.clear();
        for (int i = 0; i < labelSize; i++) {
            labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]);
        }
    }

    @Override
    public void shuffle() {
    }

    @Override
    public void padding() {
        if (labelArray == null) {
            // infer model: synthesize an all-zero one-hot label array sized to the images
            labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)];
            Arrays.fill(labelArray, 0);
        }
        int curSize = labelArray.length / numOfClass;
        int modSize = curSize - curSize / batchSize * batchSize;
        // padSize counts SAMPLES needed to reach a batchSize multiple; the original
        // "batchSize * numOfClass - modSize" over-padded (cf. AlbertDataSet.padding).
        int padSize = modSize != 0 ? batchSize - modSize : 0;
        if (padSize != 0) {
            int[] padLabelArray = new int[labelArray.length + padSize * numOfClass];
            byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES];
            System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length);
            System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length);
            for (int i = 0; i < padSize; i++) {
                int idx = (int) (Math.random() * curSize);
                System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass,
                    numOfClass);
                // Destination offset must start at the end of the ORIGINAL image data;
                // the old code used padImageArray.length here, which is always out of
                // bounds and threw ArrayIndexOutOfBoundsException whenever padding ran.
                System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray,
                    imageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES);
            }
            labelArray = padLabelArray;
            imageArray = padImageArray;
        }
        sampleSize = curSize + padSize;
        batchNum = sampleSize / batchSize;
        setPredictLabels(labelArray);
        logger.info(Common.addTag("total samples:" + sampleSize));
        logger.info(Common.addTag("total batchNum:" + batchNum));
    }

    // Decodes each one-hot label row back into its class index for accuracy checks.
    private void setPredictLabels(int[] labelArray) {
        int labelsNum = labelArray.length / numOfClass;
        targetLabels = new ArrayList<>(labelsNum);
        for (int i = 0; i < labelsNum; i++) {
            int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1));
            if (label == -1) {
                logger.severe(Common.addTag("get max index failed"));
            }
            targetLabels.add(label);
        }
    }

    // Returns the offset of the first 1 within [begin, end), or -1 if none exists.
    private int getMaxIndex(int[] nums, int begin, int end) {
        for (int i = begin; i < end; i++) {
            if (nums[i] == 1) {
                return i - begin;
            }
        }
        return -1;
    }

    @Override
    public int dataPreprocess(List<String> files) {
        String labelFile = "";
        String imageFile;
        if (files.size() == 2) {
            imageFile = files.get(0);
            labelFile = files.get(1);
        } else if (files.size() == 1) {
            imageFile = files.get(0);
        } else {
            logger.severe(Common.addTag("files size error"));
            return -1;
        }
        imageArray = FileUtil.readBinFile(imageFile);
        if (labelFile != null && !labelFile.isEmpty()) {
            byte[] labelByteArray = FileUtil.readBinFile(labelFile);
            targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE);
            // model labels use one hot
            labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass];
            Arrays.fill(labelArray, 0);
            int offset = 0;
            for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) {
                labelArray[offset * numOfClass + labelByteArray[i]] = 1;
                offset++;
            }
        } else {
            labelArray = null; // labelArray may be initialized from train
        }
        sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE;
        return sampleSize;
    }
}

View File

@ -1,85 +0,0 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.example.lenet;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.Status;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
 * Defining the Callback get model predict result.
 *
 * <p>At the end of every step the arg-max class index of each sample in the
 * batch is appended to the accumulated prediction list.
 *
 * @since v1.0
 */
public class PredictCallback extends Callback {
    private final List<Integer> predictResults = new ArrayList<>();
    private final int numOfClass;
    private final int batchSize;

    /**
     * Defining a constructor of predict callback.
     */
    public PredictCallback(LiteSession session, int batchSize, int numOfClass) {
        super(session);
        this.batchSize = batchSize;
        this.numOfClass = numOfClass;
    }

    /**
     * Get predict results.
     *
     * @return predict result.
     */
    public List<Integer> getPredictResults() {
        return predictResults;
    }

    @Override
    public Status stepBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status stepEnd() {
        Optional<MSTensor> output = searchOutputsForSize(batchSize * numOfClass);
        if (!output.isPresent()) {
            return Status.FAILED;
        }
        float[] scores = output.get().getFloatData();
        // Each sample occupies numOfClass consecutive scores.
        int begin = 0;
        for (int sample = 0; sample < batchSize; sample++, begin += numOfClass) {
            predictResults.add(CommonUtils.getMaxScoreIndex(scores, begin, begin + numOfClass));
        }
        return Status.SUCCESS;
    }

    @Override
    public Status epochBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status epochEnd() {
        return Status.SUCCESS;
    }
}

View File

@ -17,7 +17,6 @@
package com.mindspore.flclient.model;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.example.lenet.ClassifierAccuracyCallback;
import java.util.HashMap;
import java.util.logging.Logger;
@ -25,7 +24,7 @@ import java.util.logging.Logger;
public class ClientManager {
private static final HashMap<String, Client> clientMaps = new HashMap<>();
private static final Logger logger = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
private static final Logger logger = Logger.getLogger(ClientManager.class.toString());
/**
* Register client.

View File

@ -69,21 +69,21 @@ public abstract class DataSet {
* @param files data files.
* @return preprocess status.
*/
public abstract int dataPreprocess(List<String> files);
public abstract Status dataPreprocess(List<String> files);
/**
* Init dataset.
*
* @param files data files.
* @return dataset size.
* @return init status.
*/
public int init(List<String> files) {
int status = dataPreprocess(files);
if (status != 0) {
public Status init(List<String> files) {
Status status = dataPreprocess(files);
if (status != Status.SUCCESS) {
logger.severe(Common.addTag("data preprocess failed"));
return status;
}
shuffle();
return sampleSize;
return status;
}
}