# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
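"""
Test CLUEDataset: reading the CLUE benchmark tasks (AFQMC, CMNLI, CSL,
IFLYTEK, TNEWS and WSC) via mindspore.dataset, covering repeat/skip,
sharding, num_samples, device transfer and error reporting.
"""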
import os

import pytest

import mindspore.dataset as ds


def test_clue():
    """
    Test CLUEDataset with repeat and skip operations.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
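    # train.json holds 3 rows, so repeat(2) yields 6 and skip(3) drops the
    # first 3, leaving exactly 3 for the loop below. Each cell arrives as a
    # NumPy bytes scalar: .item() unwraps it, .decode("utf8") makes it a str.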
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_num_shards():
    """
    Test the num_shards parameter of CLUEDataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
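    # Splitting 3 rows across num_shards=3 leaves exactly 1 row in shard 1.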
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 1


def test_clue_num_samples():
    """
    Test the num_samples parameter of CLUEDataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
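    # num_samples=2 caps the pipeline at 2 of the 3 available rows.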
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_textline_dataset_get_datasetsize():
    """
    Test get_dataset_size of TextFileDataset on a CLUE data file.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.TextFileDataset(TRAIN_FILE)
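    # The CLUE test files store one JSON record per line, so the line-based
    # TextFileDataset should report 3 rows.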
    size = data.get_dataset_size()
    assert size == 3


def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'
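
    # Column layout differs by usage: rows from the train and dev files carry
    # a 'label', while rows from the test file carry an 'id' instead.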

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # evaluation
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'
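
    # The train pass below decodes values to str; the test and eval passes
    # keep the raw NumPy values exactly as the iterator returns them.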

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3


def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'
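
    # CSL's 'keyword' column holds multiple values per row, so each element
    # is decoded individually.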

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']]
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'
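
    # IFLYTEK rows pair a 'label' with its human-readable 'label_des'.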

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'
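
    # TNEWS 'keywords' may be empty; an empty cell arrives as a zero-size
    # NumPy array, which .item() cannot unwrap, so decode only when the
    # array is non-empty.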

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3


def test_clue_wsc():
    """
    Test WSC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'
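
    # WSC rows mix numeric columns (the span indices and 'idx') with text
    # columns; only the text columns get the bytes-to-str decode.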

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_to_device():
    """
    Test CLUE with to_device
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
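    # Smoke test: to_device() builds the device-queue transfer and send()
    # launches it; passing just means no exception is raised.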
    data = data.to_device()
    data.send()


def test_clue_invalid_files():
    """
    Test CLUE with invalid files
    """
    AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
    afqmc_train_json = os.path.join(AFQMC_DIR)
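    # The path points at the dataset directory rather than a JSON file, so
    # the file-pattern lookup must fail with a ValueError naming that path.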
    with pytest.raises(ValueError) as info:
        _ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
    assert "The following patterns did not match any files" in str(info.value)
    assert AFQMC_DIR in str(info.value)


def test_clue_exception_file_path():
    """
    Test that the error message names the data files when an exception is
    raised inside a map operation on a CLUE dataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    def exception_func(item):
        # map() passes the column value in; fail unconditionally.
        raise Exception("Error occurred!")
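
    # Apply the failing PyFunc to each column in turn; every failure must
    # surface as a RuntimeError that points back to the source data files.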
    for column in ["label", "sentence1", "sentence2"]:
        try:
            data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
            data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            for _ in data.create_dict_iterator():
                pass
            assert False
        except RuntimeError as e:
            assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


if __name__ == "__main__":
    test_clue()
    test_clue_num_shards()
    test_clue_num_samples()
    test_textline_dataset_get_datasetsize()
    test_clue_afqmc()
    test_clue_cmnli()
    test_clue_csl()
    test_clue_iflytek()
    test_clue_tnews()
    test_clue_wsc()
    test_clue_to_device()
    test_clue_invalid_files()
    test_clue_exception_file_path()