forked from mindspore-Ecosystem/mindspore
!3009 Add a sample script of data processing for fine-tuning BERT on CLUE classification dataset
Merge pull request !3009 from dessyang/master
This commit is contained in:
commit
b5dda1295b
|
@ -0,0 +1,153 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
sample script of processing CLUE classification dataset using mindspore.dataset.text for fine-tuning bert
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
|
||||
|
||||
def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
|
||||
data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
|
||||
"""Process TNEWS dataset"""
|
||||
### Loading TNEWS from CLUEDataset
|
||||
assert data_usage in ['train', 'eval', 'test']
|
||||
if data_usage == 'train':
|
||||
dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='TNEWS',
|
||||
usage=data_usage, shuffle=shuffle_dataset)
|
||||
elif data_usage == 'eval':
|
||||
dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='TNEWS',
|
||||
usage=data_usage, shuffle=shuffle_dataset)
|
||||
else:
|
||||
dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='TNEWS',
|
||||
usage=data_usage, shuffle=shuffle_dataset)
|
||||
### Processing label
|
||||
if data_usage == 'test':
|
||||
dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"],
|
||||
columns_order=["id", "label_id", "sentence"], operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0))
|
||||
else:
|
||||
label_vocab = text.Vocab.from_list(label_list)
|
||||
label_lookup = text.Lookup(label_vocab)
|
||||
dataset = dataset.map(input_columns="label_desc", output_columns="label_id", operations=label_lookup)
|
||||
### Processing sentence
|
||||
vocab = text.Vocab.from_file(bert_vocab_path)
|
||||
tokenizer = text.BertTokenizer(vocab, lower_case=True)
|
||||
lookup = text.Lookup(vocab, unknown_token='[UNK]')
|
||||
dataset = dataset.map(input_columns=["sentence"], operations=tokenizer)
|
||||
dataset = dataset.map(input_columns=["sentence"], operations=ops.Slice(slice(0, max_seq_len)))
|
||||
dataset = dataset.map(input_columns=["sentence"],
|
||||
operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'),
|
||||
append=np.array(["[SEP]"], dtype='S')))
|
||||
dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup)
|
||||
dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
|
||||
dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
|
||||
columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
|
||||
dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"],
|
||||
columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0))
|
||||
dataset = dataset.batch(batch_size)
|
||||
label = []
|
||||
text_ids = []
|
||||
mask_ids = []
|
||||
segment_ids = []
|
||||
for data in dataset:
|
||||
label.append(data[0])
|
||||
text_ids.append(data[1])
|
||||
mask_ids.append(data[2])
|
||||
segment_ids.append(data[3])
|
||||
return label, text_ids, mask_ids, segment_ids
|
||||
|
||||
|
||||
def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
|
||||
data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
|
||||
"""Process CMNLI dataset"""
|
||||
### Loading CMNLI from CLUEDataset
|
||||
assert data_usage in ['train', 'eval', 'test']
|
||||
if data_usage == 'train':
|
||||
dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='CMNLI',
|
||||
usage=data_usage, shuffle=shuffle_dataset)
|
||||
elif data_usage == 'eval':
|
||||
dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='CMNLI',
|
||||
usage=data_usage, shuffle=shuffle_dataset)
|
||||
else:
|
||||
dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='CMNLI',
|
||||
usage=data_usage, shuffle=shuffle_dataset)
|
||||
### Processing label
|
||||
if data_usage == 'test':
|
||||
dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"],
|
||||
columns_order=["id", "label_id", "sentence1", "sentence2"], operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0))
|
||||
else:
|
||||
label_vocab = text.Vocab.from_list(label_list)
|
||||
label_lookup = text.Lookup(label_vocab)
|
||||
dataset = dataset.map(input_columns="label", output_columns="label_id", operations=label_lookup)
|
||||
### Processing sentence pairs
|
||||
vocab = text.Vocab.from_file(bert_vocab_path)
|
||||
tokenizer = text.BertTokenizer(vocab, lower_case=True)
|
||||
lookup = text.Lookup(vocab, unknown_token='[UNK]')
|
||||
### Tokenizing sentences and truncate sequence pair
|
||||
dataset = dataset.map(input_columns=["sentence1"], operations=tokenizer)
|
||||
dataset = dataset.map(input_columns=["sentence2"], operations=tokenizer)
|
||||
dataset = dataset.map(input_columns=["sentence1", "sentence2"],
|
||||
operations=text.TruncateSequencePair(max_seq_len-3))
|
||||
### Adding special tokens
|
||||
dataset = dataset.map(input_columns=["sentence1"],
|
||||
operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'),
|
||||
append=np.array(["[SEP]"], dtype='S')))
|
||||
dataset = dataset.map(input_columns=["sentence2"],
|
||||
operations=ops.Concatenate(append=np.array(["[SEP]"], dtype='S')))
|
||||
### Generating segment_ids
|
||||
dataset = dataset.map(input_columns=["sentence1"], output_columns=["sentence1", "type_sentence1"],
|
||||
columns_order=["sentence1", "type_sentence1", "sentence2", "label_id"],
|
||||
operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["sentence2"], output_columns=["sentence2", "type_sentence2"],
|
||||
columns_order=["sentence1", "type_sentence1", "sentence2", "type_sentence2", "label_id"],
|
||||
operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["type_sentence1"], operations=[lookup, ops.Fill(0)])
|
||||
dataset = dataset.map(input_columns=["type_sentence2"], operations=[lookup, ops.Fill(1)])
|
||||
dataset = dataset.map(input_columns=["type_sentence1", "type_sentence2"], output_columns=["segment_ids"],
|
||||
columns_order=["sentence1", "sentence2", "segment_ids", "label_id"],
|
||||
operations=ops.Concatenate())
|
||||
dataset = dataset.map(input_columns=["segment_ids"], operations=ops.PadEnd([max_seq_len], 0))
|
||||
### Generating text_ids
|
||||
dataset = dataset.map(input_columns=["sentence1", "sentence2"], output_columns=["text_ids"],
|
||||
columns_order=["text_ids", "segment_ids", "label_id"],
|
||||
operations=ops.Concatenate())
|
||||
dataset = dataset.map(input_columns=["text_ids"], operations=lookup)
|
||||
dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
|
||||
### Generating mask_ids
|
||||
dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
|
||||
columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
|
||||
dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
|
||||
dataset = dataset.batch(batch_size)
|
||||
label = []
|
||||
text_ids = []
|
||||
mask_ids = []
|
||||
segment_ids = []
|
||||
for data in dataset:
|
||||
label.append(data[0])
|
||||
text_ids.append(data[1])
|
||||
mask_ids.append(data[2])
|
||||
segment_ids.append(data[3])
|
||||
return label, text_ids, mask_ids, segment_ids
|
Loading…
Reference in New Issue