!19599 [MS][LITE] r1.3: convert the AdBert demo into an emotion classification demo

Merge pull request !19599 from zhengjun10/r1.3
This commit is contained in:
i-robot 2021-07-08 01:51:50 +00:00 committed by Gitee
commit aaa80736a6
3 changed files with 19 additions and 20 deletions

View File

@ -28,7 +28,7 @@ import java.util.logging.Logger;
public class AdBert extends TrainModel {
private static final Logger logger = Logger.getLogger(AdBert.class.toString());
private static final int NUM_OF_CLASS = 5;
private static final int NUM_OF_CLASS = 4;
List<Feature> features;
@ -66,7 +66,7 @@ public class AdBert extends TrainModel {
tokenIdBufffer.order(ByteOrder.nativeOrder());
maskIdBufffer.order(ByteOrder.nativeOrder());
if (trainMod) {
labelIdBufffer = ByteBuffer.allocateDirect(inputSize * Integer.BYTES);
labelIdBufffer = ByteBuffer.allocateDirect(batchSize * Integer.BYTES);
labelIdBufffer.order(ByteOrder.nativeOrder());
}
numOfClass = NUM_OF_CLASS;
@ -93,14 +93,12 @@ public class AdBert extends TrainModel {
for (int j = 0; j < dataSize; j++) {
maskIdBufffer.putInt(feature.inputMasks[j]);
}
if(trainMod) {
labelIdBufffer.putInt(feature.labelIds);
}
if (!trainMod) {
labels.add(feature.labelIds);
}
if (trainMod) {
for (int j = 0; j < dataSize; j++) {
labelIdBufffer.putInt(feature.inputIds[j]);
}
}
}
List<MSTensor> inputs = trainSession.getInputs();
@ -109,10 +107,10 @@ public class AdBert extends TrainModel {
MSTensor inputIdTensor;
MSTensor maskIdTensor;
if (trainMod) {
labelIdTensor = inputs.get(0);
tokenIdTensor = inputs.get(1);
inputIdTensor = inputs.get(2);
maskIdTensor = inputs.get(3);
labelIdTensor = inputs.get(3);
tokenIdTensor = inputs.get(0);
inputIdTensor = inputs.get(1);
maskIdTensor = inputs.get(2);
labelIdTensor.setData(labelIdBufffer);
} else {
tokenIdTensor = inputs.get(0);

View File

@ -33,13 +33,13 @@ public class CustomTokenizer {
private int maxInputChars = 100;
private String[] NotSplitStrs = {"UNK"};
private String unkToken = "[UNK]";
private int maxSeqLen = 16;
private int maxSeqLen = 8;
private int vocabSize = 11682;
private Map<String, Integer> labelMap = new HashMap<String, Integer>() {{
put("beauty", 0);
put("education", 1);
put("hotel", 2);
put("travel", 3);
put("good", 0);
put("leimu", 1);
put("xiaoku", 2);
put("xin", 3);
put("other", 4);
}};
@ -199,6 +199,10 @@ public class CustomTokenizer {
}
public Feature getFeatures(List<Integer> tokens, String label) {
if (!labelMap.containsKey(label)) {
logger.severe(Common.addTag("label map not contains label:" + label));
return null;
}
int[] segmentIds = new int[maxSeqLen];
Arrays.fill(segmentIds, 0);
int[] masks = new int[maxSeqLen];

View File

@ -62,11 +62,8 @@ public class DataSet {
List<Feature> features = new ArrayList<>(examples.size());
for (int i = 0; i < examples.size(); i++) {
List<Integer> tokens = customTokenizer.tokenize(examples.get(i), trainMod);
List<Integer> tokens = customTokenizer.tokenize(examples.get(i), false);
Feature feature = customTokenizer.getFeatures(tokens, labels.get(i));
if (trainMod) {
customTokenizer.addRandomMaskAndReplace(feature, true, true);
}
features.add(feature);
}
return features;