mindspore/tests/ut/python/dataset/test_datasets_clue.py

420 lines
14 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pytest
import mindspore.dataset as ds
def test_clue():
"""
Test CLUE with repeat, skip and so on
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
data = data.repeat(2)
data = data.skip(3)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_num_shards():
"""
Test num_shards param of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 1
def test_clue_num_samples():
"""
Test num_samples param of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_textline_dataset_get_datasetsize():
"""
Test get_dataset_size of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.TextFileDataset(TRAIN_FILE)
size = data.get_dataset_size()
assert size == 3
def test_clue_afqmc():
"""
Test AFQMC for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
# evaluation
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_cmnli():
"""
Test CMNLI for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'sentence1': d['sentence1'],
'sentence2': d['sentence2']
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'],
'sentence1': d['sentence1'],
'sentence2': d['sentence2']
})
assert len(buffer) == 3
def test_clue_csl():
"""
Test CSL for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'abst': d['abst'].item().decode("utf8"),
'keyword': [i.item().decode("utf8") for i in d['keyword']],
'label': d['label'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'abst': d['abst'].item().decode("utf8"),
'keyword': [i.item().decode("utf8") for i in d['keyword']],
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'abst': d['abst'].item().decode("utf8"),
'keyword': [i.item().decode("utf8") for i in d['keyword']],
'label': d['label'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_iflytek():
"""
Test IFLYTEK for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_des': d['label_des'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8"),
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'sentence': d['sentence'].item().decode("utf8")
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_des': d['label_des'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_tnews():
"""
Test TNEWS for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_desc': d['label_desc'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8"),
'keywords':
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'id': d['id'],
'sentence': d['sentence'].item().decode("utf8"),
'keywords':
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_desc': d['label_desc'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8"),
'keywords':
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
})
assert len(buffer) == 3
def test_clue_wsc():
"""
Test WSC for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'span1_index': d['span1_index'],
'span2_index': d['span2_index'],
'span1_text': d['span1_text'].item().decode("utf8"),
'span2_text': d['span2_text'].item().decode("utf8"),
'idx': d['idx'],
'label': d['label'].item().decode("utf8"),
'text': d['text'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'span1_index': d['span1_index'],
'span2_index': d['span2_index'],
'span1_text': d['span1_text'].item().decode("utf8"),
'span2_text': d['span2_text'].item().decode("utf8"),
'idx': d['idx'],
'text': d['text'].item().decode("utf8")
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append({
'span1_index': d['span1_index'],
'span2_index': d['span2_index'],
'span1_text': d['span1_text'].item().decode("utf8"),
'span2_text': d['span2_text'].item().decode("utf8"),
'idx': d['idx'],
'label': d['label'].item().decode("utf8"),
'text': d['text'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_to_device():
"""
Test CLUE with to_device
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
data = data.to_device()
data.send()
def test_clue_invalid_files():
"""
Test CLUE with invalid files
"""
AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
afqmc_train_json = os.path.join(AFQMC_DIR)
with pytest.raises(ValueError) as info:
_ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
assert "The following patterns did not match any files" in str(info.value)
assert AFQMC_DIR in str(info.value)
def test_clue_exception_file_path():
"""
Test file info in err msg when exception occur of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
def exception_func(item):
raise Exception("Error occur!")
try:
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
for _ in data.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
data = data.map(operations=exception_func, input_columns=["sentence1"], num_parallel_workers=1)
for _ in data.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
try:
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
data = data.map(operations=exception_func, input_columns=["sentence2"], num_parallel_workers=1)
for _ in data.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
if __name__ == "__main__":
test_clue()
test_clue_num_shards()
test_clue_num_samples()
test_textline_dataset_get_datasetsize()
test_clue_afqmc()
test_clue_cmnli()
test_clue_csl()
test_clue_iflytek()
test_clue_tnews()
test_clue_wsc()
test_clue_to_device()
test_clue_invalid_files()
test_clue_exception_file_path()