fix gil issue
parent ff4974cca4
commit f703f66f6c
@@ -40,7 +40,6 @@ PythonRuntimeContext::~PythonRuntimeContext() {
    MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
  }
  if (tree_consumer_) {
    py::gil_scoped_acquire gil_acquire;
    tree_consumer_.reset();
  }
}

@@ -19,7 +19,8 @@ import os
import pytest

import mindspore.dataset as ds
-from util_minddataset import add_and_remove_cv_file # pylint: disable=unused-import
+from mindspore import log as logger
+from util_minddataset import add_and_remove_cv_file, add_and_remove_file # pylint: disable=unused-import


# pylint: disable=redefined-outer-name

@@ -64,13 +65,13 @@ def test_minddtaset_generatordataset_01(add_and_remove_cv_file):

    dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
                                  column_names=["data", "file_name", "label"], num_parallel_workers=1)
-    num_epoches = 3
-    iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
+    num_epochs = 3
+    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
    num_iter = 0
-    for _ in range(num_epoches):
+    for _ in range(num_epochs):
        for _ in iter_:
            num_iter += 1
-    assert num_iter == num_epoches * dataset_size
+    assert num_iter == num_epochs * dataset_size


# pylint: disable=redefined-outer-name

@@ -115,11 +116,74 @@ def test_minddtaset_generatordataset_exception_01(add_and_remove_cv_file):

    dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
                                  column_names=["data", "file_name", "label"], num_parallel_workers=1)
-    num_epoches = 3
-    iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
+    num_epochs = 3
+    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
    num_iter = 0
    with pytest.raises(RuntimeError) as error_info:
-        for _ in range(num_epoches):
+        for _ in range(num_epochs):
            for _ in iter_:
                num_iter += 1
    assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)
+
+
+# pylint: disable=redefined-outer-name
+def test_minddtaset_generatordataset_exception_02(add_and_remove_file):
+    """
+    Feature: Test basic two level pipeline for mixed dataset.
+    Description: Invalid column name in MindDataset
+    Expectation: Throw expected exception.
+    """
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 1
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+
+    file_paths = [file_name + "_cv" + str(i) for i in range(4)]
+    file_paths += [file_name + "_nlp" + str(i) for i in range(4)]
+    class MyIterable:
+        """ custom iteration """
+        def __init__(self, file_paths):
+            self._iter = None
+            self._index = 0
+            self._idx = 0
+            self._file_paths = file_paths
+
+        def __next__(self):
+            if self._index >= len(self._file_paths) * 10:
+                raise StopIteration
+            if self._iter:
+                try:
+                    item = next(self._iter)
+                    self._index += 1
+                except StopIteration:
+                    if self._idx >= len(self._file_paths):
+                        raise StopIteration
+                    self._iter = None
+                    return next(self)
+                return item
+            logger.info("load <<< {}.".format(self._file_paths[self._idx]))
+            self._iter = ds.MindDataset(self._file_paths[self._idx],
+                                        columns_list, num_parallel_workers=num_readers,
+                                        shuffle=None).create_tuple_iterator(num_epochs=1, output_numpy=True)
+            self._idx += 1
+            return next(self)
+
+        def __iter__(self):
+            self._index = 0
+            self._idx = 0
+            self._iter = None
+            return self
+
+        def __len__(self):
+            return len(self._file_paths) * 10
+
+    dataset = ds.GeneratorDataset(source=MyIterable(file_paths),
+                                  column_names=["data", "file_name", "label"], num_parallel_workers=1)
+    num_epochs = 1
+    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
+    num_iter = 0
+    with pytest.raises(RuntimeError) as error_info:
+        for _ in range(num_epochs):
+            for item in iter_:
+                print("item: ", item)
+                num_iter += 1
+    assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)

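For context on the "two level pipeline" the new test exercises, the sketch below reduces the pattern to its core: an outer GeneratorDataset whose source object drains one inner MindDataset iterator per file. It is illustrative only and not part of the commit; the file names part0 and part1 are hypothetical placeholders for existing MindRecord files containing the listed columns.

import mindspore.dataset as ds

class ChainedMindRecords:
    """Yield rows file by file from an inner MindDataset pipeline."""

    def __init__(self, files, columns):
        self._files = files
        self._columns = columns
        self._rows = None   # iterator over the file currently being read
        self._idx = 0       # index of the next file to open

    def __iter__(self):
        self._rows, self._idx = None, 0
        return self

    def __next__(self):
        if self._rows is None:
            if self._idx >= len(self._files):
                raise StopIteration
            inner = ds.MindDataset(self._files[self._idx], self._columns,
                                   num_parallel_workers=1, shuffle=None)
            self._rows = inner.create_tuple_iterator(num_epochs=1, output_numpy=True)
            self._idx += 1
        try:
            return next(self._rows)
        except StopIteration:
            self._rows = None   # current file exhausted, recurse to open the next one
            return next(self)

# "part0" and "part1" are hypothetical MindRecord files; the outer pipeline sees them as one dataset.
source = ChainedMindRecords(["part0", "part1"], ["data", "file_name", "label"])
dataset = ds.GeneratorDataset(source=source, column_names=["data", "file_name", "label"],
                              num_parallel_workers=1)

The new test builds the same structure, but points the inner MindDataset at files written by the add_and_remove_file fixture, so the invalid-column error surfaces through the outer GeneratorDataset.
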
@@ -16,7 +16,13 @@
This module contains common utility functions for minddataset tests.
"""
import os
+import re
+import string
+import collections
+
+import pytest
+import numpy as np

from mindspore.mindrecord import FileWriter

FILES_NUM = 4

@@ -54,6 +60,87 @@ def get_data(dir_name):
    return data_list


+def inputs(vectors, maxlen=50):
+    length = len(vectors)
+    if length > maxlen:
+        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
+    input_ = vectors + [0] * (maxlen - length)
+    mask = [1] * length + [0] * (maxlen - length)
+    segment = [0] * maxlen
+    return input_, mask, segment
+
+
+def convert_to_uni(text):
+    if isinstance(text, str):
+        return text
+    if isinstance(text, bytes):
+        return text.decode('utf-8', 'ignore')
+    raise Exception("The type %s does not convert!" % type(text))
+
+
+def load_vocab(vocab_file):
+    """load vocabulary to translate statement."""
+    vocab = collections.OrderedDict()
+    vocab.setdefault('blank', 2)
+    index = 0
+    with open(vocab_file) as reader:
+        while True:
+            tmp = reader.readline()
+            if not tmp:
+                break
+            token = convert_to_uni(tmp)
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def get_nlp_data(dir_name, vocab_file, num):
+    """
+    Return raw data of aclImdb dataset.
+
+    Args:
+        dir_name (str): String of aclImdb dataset's path.
+        vocab_file (str): String of dictionary's path.
+        num (int): Number of sample.
+
+    Returns:
+        List
+    """
+    if not os.path.isdir(dir_name):
+        raise IOError("Directory {} not exists".format(dir_name))
+    for root, _, files in os.walk(dir_name):
+        for index, file_name_extension in enumerate(files):
+            if index < num:
+                file_path = os.path.join(root, file_name_extension)
+                file_name, _ = file_name_extension.split('.', 1)
+                id_, rating = file_name.split('_', 1)
+                with open(file_path, 'r') as f:
+                    raw_content = f.read()
+
+                dictionary = load_vocab(vocab_file)
+                vectors = [dictionary.get('[CLS]')]
+                vectors += [dictionary.get(i) if i in dictionary
+                            else dictionary.get('[UNK]')
+                            for i in re.findall(r"[\w']+|[{}]"
+                                                .format(string.punctuation),
+                                                raw_content)]
+                vectors += [dictionary.get('[SEP]')]
+                input_, mask, segment = inputs(vectors)
+                input_ids = np.reshape(np.array(input_), [-1])
+                input_mask = np.reshape(np.array(mask), [1, -1])
+                segment_ids = np.reshape(np.array(segment), [2, -1])
+                data = {
+                    "label": 1,
+                    "id": id_,
+                    "rating": float(rating),
+                    "input_ids": input_ids,
+                    "input_mask": input_mask,
+                    "segment_ids": segment_ids
+                }
+                yield data
+
+
@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""

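As a quick sanity check on the padding helper added above, the following hedged example (not part of the commit; the token ids are made up) shows what inputs() returns for a short vector.

# Assumes the inputs() helper defined in the hunk above is in scope.
vectors = [101, 2009, 2003, 102]                       # four made-up token ids
input_, mask, segment = inputs(vectors, maxlen=8)
assert input_ == [101, 2009, 2003, 102, 0, 0, 0, 0]    # ids padded with 0 up to maxlen
assert mask == [1, 1, 1, 1, 0, 0, 0, 0]                # 1 marks real tokens, 0 marks padding
assert segment == [0, 0, 0, 0, 0, 0, 0, 0]             # single-segment input, all zeros
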
@@ -86,3 +173,58 @@ def add_and_remove_cv_file():
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
+
+
+@pytest.fixture
+def add_and_remove_file():
+    """add/remove file"""
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    paths = ["{}{}".format(file_name + "_cv", str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    paths += ["{}{}".format(file_name + "_nlp", str(x).rjust(1, '0'))
+              for x in range(FILES_NUM)]
+    try:
+        for x in paths:
+            if os.path.exists("{}".format(x)):
+                os.remove("{}".format(x))
+            if os.path.exists("{}.db".format(x)):
+                os.remove("{}.db".format(x))
+        writer = FileWriter(file_name + "_nlp", FILES_NUM)
+        data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
+                                 "../data/mindrecord/testAclImdbData/vocab.txt",
+                                 10))
+        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
+                           "rating": {"type": "float32"},
+                           "input_ids": {"type": "int64",
+                                         "shape": [1, -1]},
+                           "input_mask": {"type": "int64",
+                                          "shape": [1, -1]},
+                           "segment_ids": {"type": "int64",
+                                           "shape": [1, -1]}
+                           }
+        writer.add_schema(nlp_schema_json, "nlp_schema")
+        writer.add_index(["id", "rating"])
+        writer.write_raw_data(data)
+        writer.commit()
+
+        writer = FileWriter(file_name + "_cv", FILES_NUM)
+        data = get_data(CV_DIR_NAME)
+        cv_schema_json = {"id": {"type": "int32"},
+                          "file_name": {"type": "string"},
+                          "label": {"type": "int32"},
+                          "data": {"type": "bytes"}}
+        writer.add_schema(cv_schema_json, "img_schema")
+        writer.add_index(["file_name", "label"])
+        writer.write_raw_data(data)
+        writer.commit()
+
+        yield "yield_data"
+    except Exception as error:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
+        raise error
+    else:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))

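To make the fixture's output concrete, here is a hedged sketch (not part of the commit) of reading back one of the NLP MindRecord files it writes. The file name is hypothetical, since the real name is derived from PYTEST_CURRENT_TEST at run time.

import mindspore.dataset as ds

# "some_test_nlp0" stands in for the fixture-generated name (file_name + "_nlp0").
nlp_columns = ["id", "rating", "input_ids", "input_mask", "segment_ids", "label"]
data_set = ds.MindDataset("some_test_nlp0", nlp_columns,
                          num_parallel_workers=1, shuffle=None)
for row in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
    # Each row follows nlp_schema_json above: scalar id/rating/label plus
    # variable-length int64 arrays for the BERT-style inputs.
    print(row["id"], row["rating"], row["input_ids"].shape)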