bert ner for msra dataset
This commit is contained in:
parent
8e4a6842ff
commit
8b1a8a6bc1
|
@ -27,6 +27,7 @@
|
|||
- [Evaluation](#evaluation)
|
||||
- [evaluation on cola dataset when running on Ascend](#evaluation-on-cola-dataset-when-running-on-ascend)
|
||||
- [evaluation on cluener dataset when running on Ascend](#evaluation-on-cluener-dataset-when-running-on-ascend)
|
||||
- [evaluation on msra dataset when running on Ascend](#evaluation-on-msra-dataset-when-running-on-ascend)
|
||||
- [evaluation on squad v1.1 dataset when running on Ascend](#evaluation-on-squad-v11-dataset-when-running-on-ascend)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
|
@ -215,7 +216,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
|
|||
├─bert_for_finetune.py # backbone code of network
|
||||
├─bert_for_pre_training.py # backbone code of network
|
||||
├─bert_model.py # backbone code of network
|
||||
├─clue_classification_dataset_precess.py # data preprocessing
|
||||
├─finetune_data_preprocess.py # data preprocessing
|
||||
├─cluner_evaluation.py # evaluation for cluner
|
||||
├─config.py # parameter configuration for pretraining
|
||||
├─CRF.py # assessment method for clue dataset
|
||||
|
@ -301,6 +302,7 @@ options:
|
|||
--load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval
|
||||
--train_data_file_path ner tfrecord for training. E.g., train.tfrecord
|
||||
--eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result
|
||||
--dataset_format dataset format, support mindrecord or tfrecord
|
||||
--schema_file_path path to datafile schema file
|
||||
|
||||
usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL]
|
||||
|
@ -544,6 +546,30 @@ Recall 0.948683
|
|||
F1 0.920507
|
||||
```
|
||||
|
||||
#### evaluation on msra dataset when running on Ascend
|
||||
|
||||
For preprocess, you can first convert the original txt format of MSRA dataset into mindrecord by run the command as below:
|
||||
|
||||
```python
|
||||
python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len
|
||||
```
|
||||
|
||||
For finetune and evaluation, just do
|
||||
|
||||
```bash
|
||||
bash scripts/ner.sh
|
||||
```
|
||||
|
||||
The command above will run in the background, you can view training logs in ner_log.txt.
|
||||
|
||||
If you choose SpanF1 as assessment method and mode use_crf is set to be "true", the result will be as follows if evaluation is done after finetuning 10 epoches:
|
||||
|
||||
```text
|
||||
Precision 0.953826
|
||||
Recall 0.957749
|
||||
F1 0.955784
|
||||
```
|
||||
|
||||
#### evaluation on squad v1.1 dataset when running on Ascend
|
||||
|
||||
```bash
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
- [用法](#用法-1)
|
||||
- [Ascend处理器上运行后评估cola数据集](#ascend处理器上运行后评估cola数据集)
|
||||
- [Ascend处理器上运行后评估cluener数据集](#ascend处理器上运行后评估cluener数据集)
|
||||
- [Ascend处理器上运行后评估msra数据集](#ascend处理器上运行后评估msra数据集)
|
||||
- [Ascend处理器上运行后评估squad v1.1数据集](#ascend处理器上运行后评估squad-v11数据集)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
|
@ -215,7 +216,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
|
|||
├─bert_for_finetune.py # 网络骨干编码
|
||||
├─bert_for_pre_training.py # 网络骨干编码
|
||||
├─bert_model.py # 网络骨干编码
|
||||
├─clue_classification_dataset_precess.py # 数据预处理
|
||||
├─finetune_data_preprocess.py # 数据预处理
|
||||
├─cluner_evaluation.py # 评估线索生成工具
|
||||
├─config.py # 预训练参数配置
|
||||
├─CRF.py # 线索数据集评估方法
|
||||
|
@ -299,6 +300,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
|
|||
--load_finetune_checkpoint_path 如仅执行评估,提供微调检查点保存路径
|
||||
--train_data_file_path 用于保存训练数据的TFRecord文件,如train.tfrecord文件
|
||||
--eval_data_file_path 如采用f1来评估结果,则为TFRecord文件保存预测;如采用clue_benchmark来评估结果,则为JSON文件保存预测
|
||||
--dataset_format 数据集格式,支持tfrecord和mindrecord格式
|
||||
--schema_file_path 模式文件保存路径
|
||||
|
||||
用法:run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL]
|
||||
|
@ -508,6 +510,29 @@ Recall 0.948683
|
|||
F1 0.920507
|
||||
```
|
||||
|
||||
#### Ascend处理器上运行后评估msra数据集
|
||||
|
||||
您可以采用如下方式,先将MSRA数据集的原始格式在预处理流程中转换为mindrecord格式以提升性能:
|
||||
|
||||
```python
|
||||
python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len
|
||||
```
|
||||
|
||||
此后,您可以进行微调再训练和推理流程,
|
||||
|
||||
```bash
|
||||
bash scripts/ner.sh
|
||||
```
|
||||
|
||||
以上命令后台运行,您可以在ner_log.txt中查看训练日志。
|
||||
如您选择SpanF1作为评估方法并且模型结构中配置CRF模式,在微调训练10个epoch之后进行推理,可得到如下结果:
|
||||
|
||||
```text
|
||||
Precision 0.953826
|
||||
Recall 0.957749
|
||||
F1 0.955784
|
||||
```
|
||||
|
||||
#### Ascend处理器上运行后评估squad v1.1数据集
|
||||
|
||||
```bash
|
||||
|
|
|
@ -23,7 +23,7 @@ from src.bert_for_finetune import BertFinetuneCell, BertNER
|
|||
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
|
||||
from src.dataset import create_ner_dataset
|
||||
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
|
||||
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
|
||||
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
|
@ -86,7 +86,7 @@ def eval_result_print(assessment_method="accuracy", callback=None):
|
|||
if assessment_method == "accuracy":
|
||||
print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
|
||||
callback.acc_num / callback.total_num))
|
||||
elif assessment_method == "f1":
|
||||
elif assessment_method in ("f1", "spanf1"):
|
||||
print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
|
||||
print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
|
||||
print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
|
||||
|
@ -118,6 +118,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met
|
|||
callback = Accuracy()
|
||||
elif assessment_method == "f1":
|
||||
callback = F1((use_crf.lower() == "true"), num_class)
|
||||
elif assessment_method == "spanf1":
|
||||
callback = SpanF1((use_crf.lower() == "true"), tag_to_index)
|
||||
elif assessment_method == "mcc":
|
||||
callback = MCC()
|
||||
elif assessment_method == "spearman_correlation":
|
||||
|
@ -143,8 +145,8 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser(description="run ner")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
|
||||
help="Device type, default is Ascend")
|
||||
parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark"],
|
||||
help="assessment_method include: [F1, clue_benchmark], default is F1")
|
||||
parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"],
|
||||
help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1")
|
||||
parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
|
||||
help="Eable train, default is false")
|
||||
parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
|
||||
|
@ -169,6 +171,8 @@ def parse_args():
|
|||
help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--eval_data_file_path", type=str, default="",
|
||||
help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--dataset_format", type=str, default="mindrecord", choices=["mindrecord", "tfrecord"],
|
||||
help="Dataset format, support mindrecord or tfrecord")
|
||||
parser.add_argument("--schema_file_path", type=str, default="",
|
||||
help="Schema path, it is better to use absolute path")
|
||||
args_opt = parser.parse_args()
|
||||
|
@ -225,7 +229,7 @@ def run_ner():
|
|||
tag_to_index=tag_to_index, dropout_prob=0.1)
|
||||
ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
|
||||
do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
|
||||
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
|
||||
|
||||
|
@ -240,7 +244,7 @@ def run_ner():
|
|||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
|
||||
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
|
||||
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
|
||||
args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
|
||||
|
|
|
@ -18,7 +18,7 @@ echo "==========================================================================
|
|||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_ner.sh"
|
||||
echo "for example: bash scripts/run_ner.sh"
|
||||
echo "assessment_method include: [F1, clue_benchmark]"
|
||||
echo "assessment_method include: [F1, SpanF1, clue_benchmark]"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
mkdir -p ms_log
|
||||
|
@ -46,4 +46,5 @@ python ${PROJECT_DIR}/../run_ner.py \
|
|||
--load_finetune_checkpoint_path="" \
|
||||
--train_data_file_path="" \
|
||||
--eval_data_file_path="" \
|
||||
--dataset_format="tfrecord" \
|
||||
--schema_file_path="" > ner_log.txt 2>&1 &
|
||||
|
|
|
@ -68,6 +68,74 @@ class F1():
|
|||
self.FP += np.sum(pos_eva&(~pos_label))
|
||||
self.FN += np.sum((~pos_eva)&pos_label)
|
||||
|
||||
|
||||
class SpanF1():
|
||||
'''
|
||||
calculate F1、precision and recall score in span manner for NER
|
||||
'''
|
||||
def __init__(self, use_crf=False, label2id=None):
|
||||
self.TP = 0
|
||||
self.FP = 0
|
||||
self.FN = 0
|
||||
self.use_crf = use_crf
|
||||
self.label2id = label2id
|
||||
if label2id is None:
|
||||
raise ValueError("label2id info should not be empty")
|
||||
self.id2label = {}
|
||||
for key, value in label2id.items():
|
||||
self.id2label[value] = key
|
||||
|
||||
def tag2span(self, ids):
|
||||
'''
|
||||
conbert ids list to span mode
|
||||
'''
|
||||
labels = np.array([self.id2label[id] for id in ids])
|
||||
spans = []
|
||||
prev_label = None
|
||||
for idx, tag in enumerate(labels):
|
||||
tag = tag.lower()
|
||||
cur_label, label = tag[:1], tag[2:]
|
||||
if cur_label in ('b', 's'):
|
||||
spans.append((label, [idx, idx]))
|
||||
elif cur_label in ('m', 'e') and prev_label in ('b', 'm') and label == spans[-1][0]:
|
||||
spans[-1][1][1] = idx
|
||||
elif cur_label == 'o':
|
||||
pass
|
||||
else:
|
||||
spans.append((label, [idx, idx]))
|
||||
prev_label = cur_label
|
||||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans]
|
||||
|
||||
|
||||
def update(self, logits, labels):
|
||||
'''
|
||||
update span F1 score
|
||||
'''
|
||||
labels = labels.asnumpy()
|
||||
labels = np.reshape(labels, -1)
|
||||
if self.use_crf:
|
||||
backpointers, best_tag_id = logits
|
||||
best_path = postprocess(backpointers, best_tag_id)
|
||||
logit_id = []
|
||||
for ele in best_path:
|
||||
logit_id.extend(ele)
|
||||
else:
|
||||
logits = logits.asnumpy()
|
||||
logit_id = np.argmax(logits, axis=-1)
|
||||
logit_id = np.reshape(logit_id, -1)
|
||||
|
||||
label_spans = self.tag2span(labels)
|
||||
pred_spans = self.tag2span(logit_id)
|
||||
for span in pred_spans:
|
||||
if span in label_spans:
|
||||
self.TP += 1
|
||||
label_spans.remove(span)
|
||||
else:
|
||||
self.FP += 1
|
||||
for span in label_spans:
|
||||
self.FN += 1
|
||||
|
||||
|
||||
class MCC():
|
||||
'''
|
||||
Calculate Matthews Correlation Coefficient
|
||||
|
|
|
@ -52,25 +52,30 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
|
|||
return data_set
|
||||
|
||||
|
||||
def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
|
||||
data_file_path=None, schema_file_path=None, do_shuffle=True):
|
||||
def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", data_file_path=None,
|
||||
dataset_format="mindrecord", schema_file_path=None, do_shuffle=True):
|
||||
"""create finetune or evaluation dataset"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
data_set = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
if dataset_format == "mindrecord":
|
||||
dataset = ds.MindDataset([data_file_path],
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
else:
|
||||
dataset = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
|
||||
shuffle=do_shuffle)
|
||||
if assessment_method == "Spearman_correlation":
|
||||
type_cast_op_float = C.TypeCast(mstype.float32)
|
||||
data_set = data_set.map(operations=type_cast_op_float, input_columns="label_ids")
|
||||
dataset = dataset.map(operations=type_cast_op_float, input_columns="label_ids")
|
||||
else:
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
data_set = data_set.repeat(repeat_count)
|
||||
dataset = dataset.map(operations=type_cast_op, input_columns="label_ids")
|
||||
dataset = dataset.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
dataset = dataset.map(operations=type_cast_op, input_columns="input_mask")
|
||||
dataset = dataset.map(operations=type_cast_op, input_columns="input_ids")
|
||||
dataset = dataset.repeat(repeat_count)
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
return data_set
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
return dataset
|
||||
|
||||
|
||||
def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
|
||||
|
|
|
@ -18,12 +18,14 @@ sample script of processing CLUE classification dataset using mindspore.dataset.
|
|||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
from utils import convert_labels_to_index
|
||||
|
||||
|
||||
def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
|
||||
|
@ -135,3 +137,93 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage
|
|||
dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32), input_columns=["mask_ids"])
|
||||
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
|
||||
return dataset
|
||||
|
||||
|
||||
def process_cluener_msra(data_file):
|
||||
"""process MSRA dataset for CLUE"""
|
||||
content = []
|
||||
labels = []
|
||||
for line in open(data_file):
|
||||
line = line.strip()
|
||||
if line:
|
||||
word = line.split("\t")[0]
|
||||
if len(line.split("\t")) == 1:
|
||||
label = "O"
|
||||
else:
|
||||
label = line.split("\t")[1].split("\n")[0]
|
||||
if label[0] != "O":
|
||||
label = label[0] + "_" + label[2:]
|
||||
if label[0] == "I":
|
||||
label = "M" + label[1:]
|
||||
content.append(word)
|
||||
labels.append(label)
|
||||
else:
|
||||
for i in range(1, len(labels) - 1):
|
||||
if labels[i][0] == "B" and labels[i+1][0] != "M":
|
||||
labels[i] = "S" + labels[i][1:]
|
||||
elif labels[i][0] == "M" and labels[i+1][0] != labels[i][0]:
|
||||
labels[i] = "E" + labels[i][1:]
|
||||
last = len(labels) - 1
|
||||
if labels[last][0] == "B":
|
||||
labels[last] = "S" + labels[last][1:]
|
||||
elif labels[last][0] == "M":
|
||||
labels[last] = "E" + labels[last][1:]
|
||||
|
||||
yield (np.array("".join(content)), np.array(list(labels)))
|
||||
content.clear()
|
||||
labels.clear()
|
||||
continue
|
||||
|
||||
|
||||
def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128):
|
||||
"""Process MSRA dataset"""
|
||||
### Loading MSRA from CLUEDataset
|
||||
dataset = ds.GeneratorDataset(process_cluener_msra(data_dir), column_names=['text', 'label'])
|
||||
|
||||
### Processing label
|
||||
label_vocab = text.Vocab.from_list(label_list)
|
||||
label_lookup = text.Lookup(label_vocab)
|
||||
dataset = dataset.map(operations=label_lookup, input_columns="label", output_columns="label_ids")
|
||||
dataset = dataset.map(operations=ops.Concatenate(prepend=np.array([0], dtype='i')),
|
||||
input_columns=["label_ids"])
|
||||
dataset = dataset.map(operations=ops.Slice(slice(0, max_seq_len)), input_columns=["label_ids"])
|
||||
dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["label_ids"])
|
||||
### Processing sentence
|
||||
vocab = text.Vocab.from_file(bert_vocab_path)
|
||||
lookup = text.Lookup(vocab, unknown_token='[UNK]')
|
||||
unicode_char_tokenizer = text.UnicodeCharTokenizer()
|
||||
dataset = dataset.map(operations=unicode_char_tokenizer, input_columns=["text"], output_columns=["sentence"])
|
||||
dataset = dataset.map(operations=ops.Slice(slice(0, max_seq_len-2)), input_columns=["sentence"])
|
||||
dataset = dataset.map(operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'),
|
||||
append=np.array(["[SEP]"], dtype='S')), input_columns=["sentence"])
|
||||
dataset = dataset.map(operations=lookup, input_columns=["sentence"], output_columns=["input_ids"])
|
||||
dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["input_ids"])
|
||||
dataset = dataset.map(operations=ops.Duplicate(), input_columns=["input_ids"],
|
||||
output_columns=["input_ids", "input_mask"],
|
||||
column_order=["input_ids", "input_mask", "label_ids"])
|
||||
dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32), input_columns=["input_mask"])
|
||||
dataset = dataset.map(operations=ops.Duplicate(), input_columns=["input_ids"],
|
||||
output_columns=["input_ids", "segment_ids"],
|
||||
column_order=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
dataset = dataset.map(operations=ops.Fill(0), input_columns=["segment_ids"])
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="create mindrecord")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="dataset path")
|
||||
parser.add_argument("--vocab_file", type=str, default="", help="Vocab file path")
|
||||
parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length")
|
||||
parser.add_argument("--save_path", type=str, default="./my.mindrecord", help="Path to save mindrecord")
|
||||
parser.add_argument("--label2id", type=str, default="",
|
||||
help="Label2id file path, must be set for cluener2020 task")
|
||||
args_opt = parser.parse_args()
|
||||
if args_opt.label2id == "":
|
||||
raise ValueError("label2id should not be empty")
|
||||
labels_list = []
|
||||
with open(args_opt.label2id) as f:
|
||||
for tag in f:
|
||||
labels_list.append(tag.strip())
|
||||
tag_to_index = list(convert_labels_to_index(labels_list).keys())
|
||||
ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len)
|
||||
ds.save(args_opt.save_path)
|
Loading…
Reference in New Issue