diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java index c53c5047abc..fa2898009c2 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java @@ -28,7 +28,7 @@ import java.util.logging.Logger; public class AdBert extends TrainModel { private static final Logger logger = Logger.getLogger(AdBert.class.toString()); - private static final int NUM_OF_CLASS = 5; + private static final int NUM_OF_CLASS = 4; List features; @@ -66,7 +66,7 @@ public class AdBert extends TrainModel { tokenIdBufffer.order(ByteOrder.nativeOrder()); maskIdBufffer.order(ByteOrder.nativeOrder()); if (trainMod) { - labelIdBufffer = ByteBuffer.allocateDirect(inputSize * Integer.BYTES); + labelIdBufffer = ByteBuffer.allocateDirect(batchSize * Integer.BYTES); labelIdBufffer.order(ByteOrder.nativeOrder()); } numOfClass = NUM_OF_CLASS; @@ -93,14 +93,12 @@ public class AdBert extends TrainModel { for (int j = 0; j < dataSize; j++) { maskIdBufffer.putInt(feature.inputMasks[j]); } + if(trainMod) { + labelIdBufffer.putInt(feature.labelIds); + } if (!trainMod) { labels.add(feature.labelIds); } - if (trainMod) { - for (int j = 0; j < dataSize; j++) { - labelIdBufffer.putInt(feature.inputIds[j]); - } - } } List inputs = trainSession.getInputs(); @@ -109,10 +107,10 @@ public class AdBert extends TrainModel { MSTensor inputIdTensor; MSTensor maskIdTensor; if (trainMod) { - labelIdTensor = inputs.get(0); - tokenIdTensor = inputs.get(1); - inputIdTensor = inputs.get(2); - maskIdTensor = inputs.get(3); + labelIdTensor = inputs.get(3); + tokenIdTensor = inputs.get(0); + inputIdTensor = inputs.get(1); + maskIdTensor = inputs.get(2); labelIdTensor.setData(labelIdBufffer); } else { tokenIdTensor = inputs.get(0); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java index e98688513f6..977eb3313b5 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java @@ -33,13 +33,13 @@ public class CustomTokenizer { private int maxInputChars = 100; private String[] NotSplitStrs = {"UNK"}; private String unkToken = "[UNK]"; - private int maxSeqLen = 16; + private int maxSeqLen = 8; private int vocabSize = 11682; private Map labelMap = new HashMap() {{ - put("beauty", 0); - put("education", 1); - put("hotel", 2); - put("travel", 3); + put("good", 0); + put("leimu", 1); + put("xiaoku", 2); + put("xin", 3); put("other", 4); }}; @@ -199,6 +199,10 @@ public class CustomTokenizer { } public Feature getFeatures(List tokens, String label) { + if (!labelMap.containsKey(label)) { + logger.severe(Common.addTag("label map not contains label:" + label)); + return null; + } int[] segmentIds = new int[maxSeqLen]; Arrays.fill(segmentIds, 0); int[] masks = new int[maxSeqLen]; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/DataSet.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/DataSet.java index 5d752079797..1c2f982d943 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/DataSet.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/DataSet.java @@ -62,11 +62,8 @@ public class DataSet { List features = new ArrayList<>(examples.size()); for (int i = 0; i < examples.size(); i++) { - List tokens = customTokenizer.tokenize(examples.get(i), trainMod); + List tokens = customTokenizer.tokenize(examples.get(i), false); Feature feature = customTokenizer.getFeatures(tokens, labels.get(i)); - if (trainMod) { - customTokenizer.addRandomMaskAndReplace(feature, true, true); - } features.add(feature); } return features;