!27780 [MS][LITE] add flclient albert example
Merge pull request !27780 from zhengjun10/fl2
commit 9558ba49d8
@@ -14,17 +14,16 @@
  * limitations under the License.
  */

-package com.mindspore.flclient.example.lenet;
+package com.mindspore.flclient.demo.albert;

-import com.mindspore.flclient.Common;
+import com.mindspore.flclient.demo.common.ClassifierAccuracyCallback;
+import com.mindspore.flclient.demo.common.PredictCallback;
 import com.mindspore.flclient.model.Callback;
 import com.mindspore.flclient.model.Client;
 import com.mindspore.flclient.model.ClientManager;
-import com.mindspore.flclient.model.CommonUtils;
 import com.mindspore.flclient.model.DataSet;
 import com.mindspore.flclient.model.LossCallback;
 import com.mindspore.flclient.model.RunType;
-import com.mindspore.lite.MSTensor;
 import com.mindspore.lite.config.MSConfig;

 import java.util.ArrayList;
@@ -34,16 +33,17 @@ import java.util.Map;
 import java.util.logging.Logger;

 /**
- * Defining the lenet client base class.
+ * Defining the albert client class.
  *
  * @since v1.0
  */
-public class LenetClient extends Client {
-    private static final Logger logger = Logger.getLogger(LenetClient.class.toString());
-    private static final int NUM_OF_CLASS = 62;
+public class AlbertClient extends Client {
+    private static final Logger LOGGER = Logger.getLogger(AlbertClient.class.toString());
+    private static final int NUM_OF_CLASS = 4;
+    private static final int MAX_SEQ_LEN = 8;

     static {
-        ClientManager.registerClient(new LenetClient());
+        ClientManager.registerClient(new AlbertClient());
     }

     @Override
@@ -53,9 +53,9 @@ public class LenetClient extends Client {
             Callback lossCallback = new LossCallback(trainSession);
             callbacks.add(lossCallback);
         } else if (runType == RunType.EVALMODE) {
-            if (dataSet instanceof LenetDataSet) {
+            if (dataSet instanceof AlbertDataSet) {
                 Callback evalCallback = new ClassifierAccuracyCallback(trainSession, dataSet.batchSize, NUM_OF_CLASS,
-                        ((LenetDataSet) dataSet).getTargetLabels());
+                        ((AlbertDataSet) dataSet).getTargetLabels());
                 callbacks.add(evalCallback);
             }
         } else {
@@ -70,21 +70,21 @@
         Map<RunType, Integer> sampleCounts = new HashMap<>();
         List<String> trainFiles = files.getOrDefault(RunType.TRAINMODE, null);
         if (trainFiles != null) {
-            DataSet trainDataSet = new LenetDataSet(NUM_OF_CLASS);
+            DataSet trainDataSet = new AlbertDataSet(RunType.TRAINMODE, MAX_SEQ_LEN);
             trainDataSet.init(trainFiles);
             dataSets.put(RunType.TRAINMODE, trainDataSet);
             sampleCounts.put(RunType.TRAINMODE, trainDataSet.sampleSize);
         }
         List<String> evalFiles = files.getOrDefault(RunType.EVALMODE, null);
         if (evalFiles != null) {
-            LenetDataSet evalDataSet = new LenetDataSet(NUM_OF_CLASS);
+            DataSet evalDataSet = new AlbertDataSet(RunType.EVALMODE, MAX_SEQ_LEN);
             evalDataSet.init(evalFiles);
             dataSets.put(RunType.EVALMODE, evalDataSet);
             sampleCounts.put(RunType.EVALMODE, evalDataSet.sampleSize);
         }
         List<String> inferFiles = files.getOrDefault(RunType.INFERMODE, null);
         if (inferFiles != null) {
-            DataSet inferDataSet = new LenetDataSet(NUM_OF_CLASS);
+            DataSet inferDataSet = new AlbertDataSet(RunType.INFERMODE, MAX_SEQ_LEN);
             inferDataSet.init(inferFiles);
             dataSets.put(RunType.INFERMODE, inferDataSet);
             sampleCounts.put(RunType.INFERMODE, inferDataSet.sampleSize);
@@ -99,7 +99,7 @@
                 return ((ClassifierAccuracyCallback) callBack).getAccuracy();
             }
         }
-        logger.severe(Common.addTag("don not find accuracy related callback"));
+        LOGGER.severe("cannot find accuracy related callback");
         return Float.NaN;
     }

@@ -110,7 +110,7 @@
                 return ((PredictCallback) callBack).getPredictResults();
             }
         }
-        logger.severe(Common.addTag("don not find accuracy related callback"));
+        LOGGER.severe("cannot find predict related callback");
         return new ArrayList<>();
     }
 }
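The static block above is the whole wiring between the demo and the framework: merely loading AlbertClient registers an instance with ClientManager, so the framework can pick clients by a configured class name. A minimal sketch of this self-registration pattern, using hypothetical Registry/Worker/AlbertWorker stand-ins rather than the real ClientManager/Client API:

import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-ins illustrating the register-in-static-initializer pattern.
class Registry {
    private static final Map<String, Worker> WORKERS = new HashMap<>();

    static void register(Worker worker) {
        // keyed by class name, mirroring ClientManager's String-keyed HashMap
        WORKERS.put(worker.getClass().getName(), worker);
    }

    static Worker get(String className) {
        return WORKERS.get(className);
    }
}

abstract class Worker {
}

class AlbertWorker extends Worker {
    static {
        // runs once, when the class is first initialized
        Registry.register(new AlbertWorker());
    }
}

public class RegistryDemo {
    public static void main(String[] args) throws ClassNotFoundException {
        // loading the class by its configured name triggers the static
        // initializer, and with it the registration
        Class.forName("AlbertWorker");
        System.out.println(Registry.get("AlbertWorker") != null); // prints: true
    }
}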
@@ -0,0 +1,225 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.demo.albert;

import com.mindspore.flclient.model.CustomTokenizer;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.Feature;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * Defining the dataset for albert.
 *
 * @since v1.0
 */
public class AlbertDataSet extends DataSet {
    private static final Logger LOGGER = Logger.getLogger(AlbertDataSet.class.toString());
    private static final int INPUT_FILE_NUM = 3;
    private static final int IDS_FILE_INDEX = 2;
    private static final int WORD_SPLIT_NUM = 2;
    private static final int ALBERT_INPUT_SIZE = 4;
    private static final int MASK_INPUT_INDEX = 2;
    private static final int LABEL_INPUT_INDEX = 3;
    private final RunType runType;
    private final List<Feature> features = new ArrayList<>();
    private final int maxSeqLen;
    private final List<Integer> targetLabels = new ArrayList<>();

    /**
     * Defining a constructor of the albert dataset.
     */
    public AlbertDataSet(RunType runType, int maxSeqLen) {
        this.runType = runType;
        this.maxSeqLen = maxSeqLen;
    }

    /**
     * Get dataset labels.
     *
     * @return dataset target labels.
     */
    public List<Integer> getTargetLabels() {
        return targetLabels;
    }

    @Override
    public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
        // the infer, train and eval models are the same one
        if (inputsBuffer.size() != ALBERT_INPUT_SIZE) {
            LOGGER.severe("input size error");
            return;
        }
        if (batchIdx > batchNum) {
            LOGGER.severe("fill model input failed");
            return;
        }
        for (ByteBuffer inputBuffer : inputsBuffer) {
            inputBuffer.clear();
        }
        // four model inputs: token ids, input ids, input masks and label ids
        ByteBuffer tokenIdBuffer = inputsBuffer.get(0);
        ByteBuffer inputIdBuffer = inputsBuffer.get(1);
        ByteBuffer maskIdBuffer = inputsBuffer.get(MASK_INPUT_INDEX);
        ByteBuffer labelIdBuffer = inputsBuffer.get(LABEL_INPUT_INDEX);
        for (int i = 0; i < batchSize; i++) {
            Feature feature = features.get(batchIdx * batchSize + i);
            for (int j = 0; j < maxSeqLen; j++) {
                inputIdBuffer.putInt(feature.inputIds[j]);
                tokenIdBuffer.putInt(feature.tokenIds[j]);
                maskIdBuffer.putInt(feature.inputMasks[j]);
            }
            if (runType != RunType.INFERMODE) {
                labelIdBuffer.putInt(feature.labelIds);
                targetLabels.add(feature.labelIds);
            }
        }
    }

    @Override
    public void shuffle() {
    }

    @Override
    public void padding() {
        if (batchSize <= 0) {
            LOGGER.severe("batch size should be bigger than 0");
            return;
        }
        LOGGER.info("before padding, samples size:" + features.size());
        int curSize = features.size();
        int modSize = curSize - curSize / batchSize * batchSize;
        int padSize = modSize != 0 ? batchSize - modSize : 0;
        // pad the tail of the dataset with randomly chosen existing samples
        for (int i = 0; i < padSize; i++) {
            int idx = (int) (Math.random() * curSize);
            features.add(features.get(idx));
        }
        sampleSize = features.size();
        batchNum = features.size() / batchSize;
        LOGGER.info("after padding, samples size:" + features.size());
        LOGGER.info("after padding, batch num:" + batchNum);
    }

    private static List<String> readTxtFile(String file) {
        if (file == null) {
            LOGGER.severe("file cannot be empty");
            return new ArrayList<>();
        }
        Path path = Paths.get(file);
        List<String> allLines = new ArrayList<>();
        try {
            allLines = Files.readAllLines(path, StandardCharsets.UTF_8);
        } catch (IOException e) {
            LOGGER.severe("read txt file failed, please check the txt file path");
        }
        return allLines;
    }

    private Status convertTrainData(String trainFile, String vocabFile, String idsFile) {
        if (trainFile == null || vocabFile == null || idsFile == null) {
            LOGGER.severe("dataset init failed, trainFile, vocabFile and idsFile cannot be empty");
            return Status.NULLPTR;
        }
        // read the train file
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, maxSeqLen);
        List<String> allLines = readTxtFile(trainFile);
        List<String> examples = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        for (String line : allLines) {
            String[] tokens = line.split(">>>");
            if (tokens.length != WORD_SPLIT_NUM) {
                LOGGER.warning("line may have format problem, need to include >>>");
                continue;
            }
            examples.add(tokens[1]);
            tokens = tokens[0].split("<<<");
            if (tokens.length != WORD_SPLIT_NUM) {
                LOGGER.warning("line may have format problem, need to include <<<");
                continue;
            }
            labels.add(tokens[1]);
        }

        for (int i = 0; i < examples.size(); i++) {
            List<Integer> tokens = customTokenizer.tokenize(examples.get(i), runType == RunType.TRAINMODE);
            if (tokens.isEmpty()) {
                continue;
            }
            Optional<Feature> feature = customTokenizer.getFeatures(tokens, labels.get(i));
            if (!feature.isPresent()) {
                continue;
            }
            if (runType == RunType.TRAINMODE) {
                customTokenizer.addRandomMaskAndReplace(feature.get(), true, true);
            }
            feature.ifPresent(features::add);
        }
        sampleSize = features.size();
        return Status.SUCCESS;
    }

    private Status convertInferData(String inferFile, String vocabFile, String idsFile) {
        if (inferFile == null || vocabFile == null || idsFile == null) {
            LOGGER.severe("dataset init failed, inferFile, vocabFile and idsFile cannot be empty");
            return Status.NULLPTR;
        }
        // read the infer file
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, maxSeqLen);
        List<String> allLines = readTxtFile(inferFile);
        for (String line : allLines) {
            if (line.isEmpty()) {
                continue;
            }
            List<Integer> tokens = customTokenizer.tokenize(line, runType == RunType.TRAINMODE);
            Optional<Feature> feature = customTokenizer.getFeatures(tokens, "other");
            if (!feature.isPresent()) {
                continue;
            }
            features.add(feature.get());
        }
        sampleSize = features.size();
        return Status.SUCCESS;
    }

    @Override
    public Status dataPreprocess(List<String> files) {
        if (files.size() != INPUT_FILE_NUM) {
            LOGGER.severe("files size error");
            return Status.FAILED;
        }
        String dataFile = files.get(0);
        String vocabFile = files.get(1);
        String idsFile = files.get(IDS_FILE_INDEX);
        if (runType == RunType.TRAINMODE || runType == RunType.EVALMODE) {
            return convertTrainData(dataFile, vocabFile, idsFile);
        } else {
            return convertInferData(dataFile, vocabFile, idsFile);
        }
    }
}
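convertTrainData above expects every training line to carry its label between <<< and >>> markers, with the example text after >>>. A self-contained sketch of that parse on an illustrative line (the sample content and the id prefix are made up, not taken from the real dataset):

public class LineFormatSketch {
    public static void main(String[] args) {
        // illustrative only; the parse below mirrors convertTrainData exactly
        String line = "12<<<good>>>what a nice day";
        String[] tokens = line.split(">>>");
        if (tokens.length != 2) {
            throw new IllegalArgumentException("line must include >>>");
        }
        String text = tokens[1];              // "what a nice day"
        tokens = tokens[0].split("<<<");
        if (tokens.length != 2) {
            throw new IllegalArgumentException("line must include <<<");
        }
        String label = tokens[1];             // "good"
        System.out.println("label=" + label + ", text=" + text);
    }
}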
@@ -0,0 +1,313 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.demo.albert;

import com.mindspore.flclient.model.Feature;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * Custom tokenizer class.
 *
 * @since v1.0
 */
public class CustomTokenizer {
    private static final Logger LOGGER = Logger.getLogger(CustomTokenizer.class.toString());
    private static final char CHINESE_START_CODE = '\u4e00';
    private static final char CHINESE_END_CODE = '\u9fa5';
    private static final int VOCAB_SIZE = 11682;
    private static final int RESERVED_LEN = 2;
    private static final int FILL_NUM = 103;
    private static final float LOW_THRESHOLD = 0.15f;
    private static final float MID_THRESHOLD = 0.8f;
    private static final float HIGH_THRESHOLD = 0.9f;
    private final Map<String, Integer> vocabs = new HashMap<>();
    private final int maxInputChars = 100;
    private final String[] notSplitStrs = {"UNK"};
    private int maxSeqLen = 8;
    private Map<String, Integer> labelMap = new HashMap<String, Integer>() {
        {
            put("good", 0);
            put("leimu", 1);
            put("xiaoku", 2);
            put("xin", 3);
            put("other", 4);
        }
    };

    private List<String> getPieceToken(String token) {
        // greedy longest-match-first WordPiece split; non-leading pieces carry a "##" prefix
        List<String> subTokens = new ArrayList<>();
        boolean isBad = false;
        int start = 0;
        int tokenLen = token.length();
        while (start < tokenLen) {
            int end = tokenLen;
            String curStr = "";
            while (start < end) {
                String subStr = token.substring(start, end);
                if (start > 0) {
                    subStr = "##" + subStr;
                }
                if (vocabs.get(subStr) != null) {
                    curStr = subStr;
                    break;
                }
                end = end - 1;
            }
            if (curStr.isEmpty()) {
                isBad = true;
                break;
            }
            subTokens.add(curStr);
            start = end;
        }
        if (isBad) {
            return new ArrayList<>(Collections.singletonList("[UNK]"));
        } else {
            return subTokens;
        }
    }

    /**
     * Init the tokenizer.
     *
     * @param vocabFile vocab file path.
     * @param idsFile word id file path.
     * @param seqLen max sequence length to clamp to.
     */
    public void init(String vocabFile, String idsFile, int seqLen) {
        if (vocabFile == null || idsFile == null) {
            LOGGER.severe("vocabFile and idsFile cannot be empty");
            return;
        }
        Path vocabPath = Paths.get(vocabFile);
        List<String> vocabLines;
        try {
            vocabLines = Files.readAllLines(vocabPath, StandardCharsets.UTF_8);
        } catch (IOException e) {
            LOGGER.severe("read vocab file failed, please check the vocab file path");
            return;
        }
        Path idsPath = Paths.get(idsFile);
        List<String> idsLines;
        try {
            idsLines = Files.readAllLines(idsPath, StandardCharsets.UTF_8);
        } catch (IOException e) {
            LOGGER.severe("read ids file failed, please check the ids file path");
            return;
        }
        if (vocabLines.size() < idsLines.size()) {
            LOGGER.severe("the vocab file must have at least as many lines as the ids file");
            return;
        }
        for (int i = 0; i < idsLines.size(); ++i) {
            try {
                vocabs.put(vocabLines.get(i), Integer.parseInt(idsLines.get(i)));
            } catch (NumberFormatException e) {
                LOGGER.severe("the ids file has invalid content");
                return;
            }
        }
        maxSeqLen = seqLen;
    }

    /**
     * Check whether a char is a Chinese char or a punctuation char.
     *
     * @param trimChar input char.
     * @return true if the char is a Chinese or punctuation char, else false.
     */
    public boolean isChineseOrPunc(char trimChar) {
        // is a chinese char
        if (trimChar >= CHINESE_START_CODE && trimChar <= CHINESE_END_CODE) {
            return true;
        }
        // is a punctuation char, according to the ASCII table
        boolean isFrontPuncChar = (trimChar >= '!' && trimChar <= '/') || (trimChar >= ':' && trimChar <= '@');
        boolean isBackPuncChar = (trimChar >= '[' && trimChar <= '`') || (trimChar >= '{' && trimChar <= '~');
        return isFrontPuncChar || isBackPuncChar;
    }

    /**
     * Split text.
     *
     * @param text input text.
     * @return split string array.
     */
    public String[] splitText(String text) {
        if (text == null) {
            return new String[0];
        }
        // trim the text, surround Chinese/punctuation chars with spaces, then split on whitespace
        String trimText = text.trim();
        StringBuilder cleanText = new StringBuilder();
        for (int i = 0; i < trimText.length(); i++) {
            if (isChineseOrPunc(trimText.charAt(i))) {
                cleanText.append(" ").append(trimText.charAt(i)).append(" ");
            } else {
                cleanText.append(trimText.charAt(i));
            }
        }
        return cleanText.toString().trim().split("\\s+");
    }

    /**
     * Combine tokens into word pieces.
     *
     * @param tokens input tokens.
     * @return word pieces.
     */
    public List<String> wordPieceTokenize(String[] tokens) {
        if (tokens == null) {
            return new ArrayList<>();
        }
        List<String> outputTokens = new ArrayList<>();
        for (String token : tokens) {
            List<String> subTokens = getPieceToken(token);
            outputTokens.addAll(subTokens);
        }
        return outputTokens;
    }

    /**
     * Convert tokens to ids.
     *
     * @param tokens input tokens.
     * @param isCycTrunc whether over-long input is cyclically truncated.
     * @return ids.
     */
    public List<Integer> convertTokensToIds(List<String> tokens, boolean isCycTrunc) {
        int seqLen = tokens.size();
        List<String> truncTokens;
        if (tokens.size() > maxSeqLen - RESERVED_LEN) {
            if (isCycTrunc) {
                int randIndex = (int) (Math.random() * seqLen);
                if (randIndex > seqLen - maxSeqLen + RESERVED_LEN) {
                    // copy the sublist views, so that appending cannot modify the backing list
                    List<String> rearPart = new ArrayList<>(tokens.subList(randIndex, seqLen));
                    List<String> frontPart = tokens.subList(0, randIndex + maxSeqLen - RESERVED_LEN - seqLen);
                    rearPart.addAll(frontPart);
                    truncTokens = rearPart;
                } else {
                    truncTokens = new ArrayList<>(tokens.subList(randIndex, randIndex + maxSeqLen - RESERVED_LEN));
                }
            } else {
                truncTokens = new ArrayList<>(tokens.subList(0, maxSeqLen - RESERVED_LEN));
            }
        } else {
            truncTokens = new ArrayList<>(tokens);
        }
        truncTokens.add(0, "[CLS]");
        truncTokens.add("[SEP]");
        List<Integer> ids = new ArrayList<>(truncTokens.size());
        for (String token : truncTokens) {
            ids.add(vocabs.getOrDefault(token, vocabs.get("[UNK]")));
        }
        return ids;
    }

    /**
     * Add random mask and replace to a feature, BERT style: each position is chosen with
     * probability LOW_THRESHOLD; a chosen position is replaced by FILL_NUM (the mask id)
     * with probability MID_THRESHOLD, kept unchanged with probability
     * HIGH_THRESHOLD - MID_THRESHOLD, and replaced by a random vocab id otherwise.
     *
     * @param feature text feature.
     * @param isKeptFirst whether to keep the first position unchanged.
     * @param isKeptLast whether to keep the last position unchanged.
     */
    public void addRandomMaskAndReplace(Feature feature, boolean isKeptFirst, boolean isKeptLast) {
        if (feature == null) {
            return;
        }
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 1);
        int[] replaces = new int[maxSeqLen];
        Arrays.fill(replaces, 0); // by default keep the original id: id * 1 + 0
        int[] inputIds = feature.inputIds;
        for (int i = 0; i < feature.seqLen; i++) {
            double rand1 = Math.random();
            if (rand1 < LOW_THRESHOLD) {
                masks[i] = 0;
                double rand2 = Math.random();
                if (rand2 < MID_THRESHOLD) {
                    replaces[i] = FILL_NUM;
                } else if (rand2 < HIGH_THRESHOLD) {
                    masks[i] = 1;
                } else {
                    replaces[i] = (int) (Math.random() * VOCAB_SIZE);
                }
            }
            if (isKeptFirst && i == 0) {
                // keep the first position ([CLS]) unchanged
                masks[i] = 1;
                replaces[i] = 0;
            }
            if (isKeptLast && i == feature.seqLen - 1) {
                // keep the last position ([SEP]) unchanged
                masks[i] = 1;
                replaces[i] = 0;
            }
            inputIds[i] = inputIds[i] * masks[i] + replaces[i];
        }
    }

    /**
     * Get a feature.
     *
     * @param tokens input token ids.
     * @param label input label.
     * @return feature.
     */
    public Optional<Feature> getFeatures(List<Integer> tokens, String label) {
        if (tokens == null || label == null) {
            LOGGER.warning("tokens or label is null");
            return Optional.empty();
        }
        if (!labelMap.containsKey(label)) {
            return Optional.empty();
        }
        int[] segmentIds = new int[maxSeqLen];
        Arrays.fill(segmentIds, 0);
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 0);
        Arrays.fill(masks, 0, tokens.size(), 1); // the tokens size is guaranteed to be at most maxSeqLen
        int[] inputIds = new int[maxSeqLen];
        Arrays.fill(inputIds, 0);
        for (int i = 0; i < tokens.size(); i++) {
            inputIds[i] = tokens.get(i);
        }
        return Optional.of(new Feature(inputIds, masks, segmentIds, labelMap.get(label), tokens.size()));
    }

    /**
     * Tokenize text to token ids.
     *
     * @param text input text.
     * @param isTrainMode whether to work in train mode.
     * @return token ids.
     */
    public List<Integer> tokenize(String text, boolean isTrainMode) {
        if (text == null) {
            LOGGER.warning("text is empty, skip it");
            return new ArrayList<>();
        }
        String[] splitTokens = splitText(text);
        List<String> wordPieceTokens = wordPieceTokenize(splitTokens);
        return convertTokensToIds(wordPieceTokens, isTrainMode); // train mode needs cyclic truncation
    }
}
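A hedged usage sketch of the tokenizer pipeline (init, tokenize, getFeatures). The file names are assumptions, and it presumes the com.mindspore.flclient.model variants of CustomTokenizer and Feature, as imported by AlbertDataSet above, expose the same API and public fields as the sources shown here:

import com.mindspore.flclient.model.CustomTokenizer;
import com.mindspore.flclient.model.Feature;

import java.util.List;
import java.util.Optional;

public class TokenizerSketch {
    public static void main(String[] args) {
        CustomTokenizer tokenizer = new CustomTokenizer();
        // assumed local copies of the demo's vocab and word-id files
        tokenizer.init("vocab.txt", "vocab_map_ids.txt", 8);

        // train mode enables cyclic truncation of over-long sequences
        List<Integer> ids = tokenizer.tokenize("what a nice day", true);

        // the label must be one of the labelMap keys: good, leimu, xiaoku, xin, other
        Optional<Feature> feature = tokenizer.getFeatures(ids, "good");
        feature.ifPresent(f -> System.out.println("seqLen=" + f.seqLen));
    }
}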
@@ -1,111 +0,0 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.example.lenet;

import com.mindspore.flclient.Common;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.Status;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;

import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * Defining the Callback calculate classifier model.
 *
 * @since v1.0
 */
public class ClassifierAccuracyCallback extends Callback {
    private static final Logger logger = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
    private final int numOfClass;
    private final int batchSize;
    private final List<Integer> targetLabels;
    private float accuracy;

    /**
     * Defining a constructor of ClassifierAccuracyCallback.
     */
    public ClassifierAccuracyCallback(LiteSession session, int batchSize, int numOfClass, List<Integer> targetLabels) {
        super(session);
        this.batchSize = batchSize;
        this.numOfClass = numOfClass;
        this.targetLabels = targetLabels;
    }

    /**
     * Get eval accuracy.
     *
     * @return accuracy.
     */
    public float getAccuracy() {
        return accuracy;
    }

    @Override
    public Status stepBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status stepEnd() {
        Status status = calAccuracy();
        if (status != Status.SUCCESS) {
            return status;
        }
        steps++;
        return Status.SUCCESS;
    }

    @Override
    public Status epochBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status epochEnd() {
        logger.info(Common.addTag("average accuracy:" + steps + ",acc is:" + accuracy / steps));
        accuracy = accuracy / steps;
        steps = 0;
        return Status.SUCCESS;
    }

    private Status calAccuracy() {
        if (targetLabels == null || targetLabels.isEmpty()) {
            logger.severe(Common.addTag("labels cannot be null"));
            return Status.NULLPTR;
        }
        Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
        if (!outputTensor.isPresent()) {
            return Status.NULLPTR;
        }
        float[] scores = outputTensor.get().getFloatData();
        int hitCounts = 0;
        for (int b = 0; b < batchSize; b++) {
            int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
            if (targetLabels.get(b + steps * batchSize) == predictIdx) {
                hitCounts += 1;
            }
        }
        accuracy += ((float) (hitCounts) / batchSize);
        logger.info(Common.addTag("steps:" + steps + ",acc is:" + (float) (hitCounts) / batchSize));
        return Status.SUCCESS;
    }
}
@@ -1,188 +0,0 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.example.lenet;

import com.mindspore.flclient.Common;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.FileUtil;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

/**
 * Defining the minist dataset for lenet.
 *
 * @since v1.0
 */
public class LenetDataSet extends DataSet {
    private static final Logger logger = Logger.getLogger(DataSet.class.toString());

    private static final int IMAGE_SIZE = 32 * 32 * 3;

    private static final int FLOAT_BYTE_SIZE = 4;

    private byte[] imageArray;

    private int[] labelArray;

    private final int numOfClass;

    private List<Integer> targetLabels;

    /**
     * Defining a constructor of lenet dataset.
     */
    public LenetDataSet(int numOfClass) {
        this.numOfClass = numOfClass;
    }

    /**
     * Get dataset labels.
     *
     * @return dataset target labels.
     */
    public List<Integer> getTargetLabels() {
        return targetLabels;
    }

    @Override
    public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
        // infer,train,eval model is same one
        if (inputsBuffer.size() != 2) {
            logger.severe(Common.addTag("input size error"));
            return;
        }
        if (batchIdx > batchNum) {
            logger.severe(Common.addTag("fill model image input failed"));
            return;
        }
        for (ByteBuffer inputBuffer : inputsBuffer) {
            inputBuffer.clear();
        }
        ByteBuffer imageBuffer = inputsBuffer.get(0);
        ByteBuffer labelIdBuffer = inputsBuffer.get(1);
        int imageInputBytes = IMAGE_SIZE * batchSize * Float.BYTES;
        for (int i = 0; i < imageInputBytes; i++) {
            imageBuffer.put(imageArray[batchIdx * imageInputBytes + i]);
        }
        if (labelArray == null) {
            return;
        }
        int labelSize = batchSize * numOfClass;
        if ((batchIdx + 1) * labelSize - 1 >= labelArray.length) {
            logger.severe(Common.addTag("fill model label input failed"));
            return;
        }
        labelIdBuffer.clear();
        for (int i = 0; i < labelSize; i++) {
            labelIdBuffer.putFloat(labelArray[batchIdx * labelSize + i]);
        }
    }

    @Override
    public void shuffle() {
    }

    @Override
    public void padding() {
        if (labelArray == null) { // infer model
            labelArray = new int[imageArray.length * numOfClass / (IMAGE_SIZE * Float.BYTES)];
            Arrays.fill(labelArray, 0);
        }
        int curSize = labelArray.length / numOfClass;
        int modSize = curSize - curSize / batchSize * batchSize;
        int padSize = modSize != 0 ? batchSize * numOfClass - modSize : 0;
        if (padSize != 0) {
            int[] padLabelArray = new int[labelArray.length + padSize * numOfClass];
            byte[] padImageArray = new byte[imageArray.length + padSize * IMAGE_SIZE * Float.BYTES];
            System.arraycopy(labelArray, 0, padLabelArray, 0, labelArray.length);
            System.arraycopy(imageArray, 0, padImageArray, 0, imageArray.length);
            for (int i = 0; i < padSize; i++) {
                int idx = (int) (Math.random() * curSize);
                System.arraycopy(labelArray, idx * numOfClass, padLabelArray, labelArray.length + i * numOfClass,
                        numOfClass);
                System.arraycopy(imageArray, idx * IMAGE_SIZE * Float.BYTES, padImageArray,
                        padImageArray.length + i * IMAGE_SIZE * Float.BYTES, IMAGE_SIZE * Float.BYTES);
            }
            labelArray = padLabelArray;
            imageArray = padImageArray;
        }
        sampleSize = curSize + padSize;
        batchNum = sampleSize / batchSize;
        setPredictLabels(labelArray);
        logger.info(Common.addTag("total samples:" + sampleSize));
        logger.info(Common.addTag("total batchNum:" + batchNum));
    }

    private void setPredictLabels(int[] labelArray) {
        int labels_num = labelArray.length / numOfClass;
        targetLabels = new ArrayList<>(labels_num);
        for (int i = 0; i < labels_num; i++) {
            int label = getMaxIndex(labelArray, numOfClass * i, numOfClass * (i + 1));
            if (label == -1) {
                logger.severe(Common.addTag("get max index failed"));
            }
            targetLabels.add(label);
        }
    }

    private int getMaxIndex(int[] nums, int begin, int end) {
        for (int i = begin; i < end; i++) {
            if (nums[i] == 1) {
                return i - begin;
            }
        }
        return -1;
    }

    @Override
    public int dataPreprocess(List<String> files) {
        String labelFile = "";
        String imageFile;
        if (files.size() == 2) {
            imageFile = files.get(0);
            labelFile = files.get(1);
        } else if (files.size() == 1) {
            imageFile = files.get(0);
        } else {
            logger.severe(Common.addTag("files size error"));
            return -1;
        }
        imageArray = FileUtil.readBinFile(imageFile);
        if (labelFile != null && !labelFile.isEmpty()) {
            byte[] labelByteArray = FileUtil.readBinFile(labelFile);
            targetLabels = new ArrayList<>(labelByteArray.length / FLOAT_BYTE_SIZE);
            // model labels use one hot
            labelArray = new int[labelByteArray.length / FLOAT_BYTE_SIZE * numOfClass];
            Arrays.fill(labelArray, 0);
            int offset = 0;
            for (int i = 0; i < labelByteArray.length; i += FLOAT_BYTE_SIZE) {
                labelArray[offset * numOfClass + labelByteArray[i]] = 1;
                offset++;
            }
        } else {
            labelArray = null; // labelArray may be initialized from train
        }
        sampleSize = imageArray.length / IMAGE_SIZE / FLOAT_BYTE_SIZE;
        return sampleSize;
    }
}
@@ -1,85 +0,0 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.example.lenet;

import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.Status;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Defining the Callback get model predict result.
 *
 * @since v1.0
 */
public class PredictCallback extends Callback {
    private final List<Integer> predictResults = new ArrayList<>();
    private final int numOfClass;
    private final int batchSize;

    /**
     * Defining a constructor of predict callback.
     */
    public PredictCallback(LiteSession session, int batchSize, int numOfClass) {
        super(session);
        this.batchSize = batchSize;
        this.numOfClass = numOfClass;
    }

    /**
     * Get predict results.
     *
     * @return predict result.
     */
    public List<Integer> getPredictResults() {
        return predictResults;
    }

    @Override
    public Status stepBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status stepEnd() {
        Optional<MSTensor> outputTensor = searchOutputsForSize(batchSize * numOfClass);
        if (!outputTensor.isPresent()) {
            return Status.FAILED;
        }
        float[] scores = outputTensor.get().getFloatData();
        for (int b = 0; b < batchSize; b++) {
            int predictIdx = CommonUtils.getMaxScoreIndex(scores, numOfClass * b, numOfClass * b + numOfClass);
            predictResults.add(predictIdx);
        }
        return Status.SUCCESS;
    }

    @Override
    public Status epochBegin() {
        return Status.SUCCESS;
    }

    @Override
    public Status epochEnd() {
        return Status.SUCCESS;
    }
}
@@ -17,7 +17,6 @@
 package com.mindspore.flclient.model;

 import com.mindspore.flclient.Common;
-import com.mindspore.flclient.example.lenet.ClassifierAccuracyCallback;

 import java.util.HashMap;
 import java.util.logging.Logger;

@@ -25,7 +24,7 @@ import java.util.logging.Logger;
 public class ClientManager {
     private static final HashMap<String, Client> clientMaps = new HashMap<>();

-    private static final Logger logger = Logger.getLogger(ClassifierAccuracyCallback.class.toString());
+    private static final Logger logger = Logger.getLogger(ClientManager.class.toString());

     /**
      * Register client.
@@ -69,21 +69,21 @@ public abstract class DataSet {
      * @param files data files.
      * @return preprocess status.
      */
-    public abstract int dataPreprocess(List<String> files);
+    public abstract Status dataPreprocess(List<String> files);

     /**
      * Init dataset.
      *
      * @param files data files.
-     * @return dataset size.
+     * @return init status.
      */
-    public int init(List<String> files) {
-        int status = dataPreprocess(files);
-        if (status != 0) {
+    public Status init(List<String> files) {
+        Status status = dataPreprocess(files);
+        if (status != Status.SUCCESS) {
             logger.severe(Common.addTag("data preprocess failed"));
             return status;
         }
         shuffle();
-        return sampleSize;
+        return status;
     }
 }
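With this change a DataSet subclass reports preprocessing failure as an explicit Status instead of a sentinel sample count. A minimal sketch of the new contract, assuming shuffle, padding and fillInputBuffer are the remaining abstract methods and sampleSize an inherited field, as the subclasses above suggest; NoopDataSet is an illustrative name, not part of the commit:

import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.Status;

import java.nio.ByteBuffer;
import java.util.List;

public class NoopDataSet extends DataSet {
    @Override
    public Status dataPreprocess(List<String> files) {
        if (files == null || files.isEmpty()) {
            return Status.FAILED; // failure is a Status now, not a -1 sample count
        }
        sampleSize = 0; // inherited from DataSet
        return Status.SUCCESS;
    }

    @Override
    public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
    }

    @Override
    public void shuffle() {
    }

    @Override
    public void padding() {
    }
}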