!19599 [MS][LITE] r1.3 modify adbert to emotion classify demo
Merge pull request !19599 from zhengjun10/r1.3
This commit is contained in:
commit
aaa80736a6
|
@ -28,7 +28,7 @@ import java.util.logging.Logger;
|
|||
public class AdBert extends TrainModel {
|
||||
private static final Logger logger = Logger.getLogger(AdBert.class.toString());
|
||||
|
||||
private static final int NUM_OF_CLASS = 5;
|
||||
private static final int NUM_OF_CLASS = 4;
|
||||
|
||||
List<Feature> features;
|
||||
|
||||
|
@ -66,7 +66,7 @@ public class AdBert extends TrainModel {
|
|||
tokenIdBufffer.order(ByteOrder.nativeOrder());
|
||||
maskIdBufffer.order(ByteOrder.nativeOrder());
|
||||
if (trainMod) {
|
||||
labelIdBufffer = ByteBuffer.allocateDirect(inputSize * Integer.BYTES);
|
||||
labelIdBufffer = ByteBuffer.allocateDirect(batchSize * Integer.BYTES);
|
||||
labelIdBufffer.order(ByteOrder.nativeOrder());
|
||||
}
|
||||
numOfClass = NUM_OF_CLASS;
|
||||
|
@ -93,14 +93,12 @@ public class AdBert extends TrainModel {
|
|||
for (int j = 0; j < dataSize; j++) {
|
||||
maskIdBufffer.putInt(feature.inputMasks[j]);
|
||||
}
|
||||
if(trainMod) {
|
||||
labelIdBufffer.putInt(feature.labelIds);
|
||||
}
|
||||
if (!trainMod) {
|
||||
labels.add(feature.labelIds);
|
||||
}
|
||||
if (trainMod) {
|
||||
for (int j = 0; j < dataSize; j++) {
|
||||
labelIdBufffer.putInt(feature.inputIds[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
List<MSTensor> inputs = trainSession.getInputs();
|
||||
|
@ -109,10 +107,10 @@ public class AdBert extends TrainModel {
|
|||
MSTensor inputIdTensor;
|
||||
MSTensor maskIdTensor;
|
||||
if (trainMod) {
|
||||
labelIdTensor = inputs.get(0);
|
||||
tokenIdTensor = inputs.get(1);
|
||||
inputIdTensor = inputs.get(2);
|
||||
maskIdTensor = inputs.get(3);
|
||||
labelIdTensor = inputs.get(3);
|
||||
tokenIdTensor = inputs.get(0);
|
||||
inputIdTensor = inputs.get(1);
|
||||
maskIdTensor = inputs.get(2);
|
||||
labelIdTensor.setData(labelIdBufffer);
|
||||
} else {
|
||||
tokenIdTensor = inputs.get(0);
|
||||
|
|
|
@ -33,13 +33,13 @@ public class CustomTokenizer {
|
|||
private int maxInputChars = 100;
|
||||
private String[] NotSplitStrs = {"UNK"};
|
||||
private String unkToken = "[UNK]";
|
||||
private int maxSeqLen = 16;
|
||||
private int maxSeqLen = 8;
|
||||
private int vocabSize = 11682;
|
||||
private Map<String, Integer> labelMap = new HashMap<String, Integer>() {{
|
||||
put("beauty", 0);
|
||||
put("education", 1);
|
||||
put("hotel", 2);
|
||||
put("travel", 3);
|
||||
put("good", 0);
|
||||
put("leimu", 1);
|
||||
put("xiaoku", 2);
|
||||
put("xin", 3);
|
||||
put("other", 4);
|
||||
}};
|
||||
|
||||
|
@ -199,6 +199,10 @@ public class CustomTokenizer {
|
|||
}
|
||||
|
||||
public Feature getFeatures(List<Integer> tokens, String label) {
|
||||
if (!labelMap.containsKey(label)) {
|
||||
logger.severe(Common.addTag("label map not contains label:" + label));
|
||||
return null;
|
||||
}
|
||||
int[] segmentIds = new int[maxSeqLen];
|
||||
Arrays.fill(segmentIds, 0);
|
||||
int[] masks = new int[maxSeqLen];
|
||||
|
|
|
@ -62,11 +62,8 @@ public class DataSet {
|
|||
|
||||
List<Feature> features = new ArrayList<>(examples.size());
|
||||
for (int i = 0; i < examples.size(); i++) {
|
||||
List<Integer> tokens = customTokenizer.tokenize(examples.get(i), trainMod);
|
||||
List<Integer> tokens = customTokenizer.tokenize(examples.get(i), false);
|
||||
Feature feature = customTokenizer.getFeatures(tokens, labels.get(i));
|
||||
if (trainMod) {
|
||||
customTokenizer.addRandomMaskAndReplace(feature, true, true);
|
||||
}
|
||||
features.add(feature);
|
||||
}
|
||||
return features;
|
||||
|
|
Loading…
Reference in New Issue