fix gil issue

liyong 2022-03-26 14:24:16 +08:00
parent ff4974cca4
commit f703f66f6c
3 changed files with 214 additions and 9 deletions

View File

@@ -40,7 +40,6 @@ PythonRuntimeContext::~PythonRuntimeContext() {
     MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
   }
   if (tree_consumer_) {
-    py::gil_scoped_acquire gil_acquire;
     tree_consumer_.reset();
   }
 }
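The one-line change above drops a GIL acquisition before tree_consumer_.reset(). The commit message does not spell out the motive, but a plausible reading (an assumption, not stated in the source) is the classic self-deadlock: if the destructor runs on a thread that already holds the GIL, or during interpreter shutdown, re-acquiring it blocks forever. A minimal Python sketch of that failure pattern, with an ordinary non-reentrant lock standing in for the GIL:

import threading

# An ordinary non-reentrant lock stands in for the GIL (assumption:
# this models why the removed py::gil_scoped_acquire could hang).
gil = threading.Lock()

def destroy_consumer():
    # Cleanup that insists on re-taking the "GIL"; a timeout
    # substitutes for hanging so the sketch terminates.
    if not gil.acquire(timeout=1):
        print("would deadlock: cleanup blocked on a lock the caller holds")
        return
    try:
        print("cleanup ran")
    finally:
        gil.release()

with gil:               # the calling thread already holds the lock...
    destroy_consumer()  # ...so the inner acquire can never succeed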

View File

@@ -19,7 +19,8 @@ import os
 import pytest
 import mindspore.dataset as ds
-from util_minddataset import add_and_remove_cv_file  # pylint: disable=unused-import
+from mindspore import log as logger
+from util_minddataset import add_and_remove_cv_file, add_and_remove_file  # pylint: disable=unused-import

 # pylint: disable=redefined-outer-name
@@ -64,13 +65,13 @@ def test_minddtaset_generatordataset_01(add_and_remove_cv_file):
     dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
                                   column_names=["data", "file_name", "label"], num_parallel_workers=1)
-    num_epoches = 3
-    iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
+    num_epochs = 3
+    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
     num_iter = 0
-    for _ in range(num_epoches):
+    for _ in range(num_epochs):
         for _ in iter_:
             num_iter += 1
-    assert num_iter == num_epoches * dataset_size
+    assert num_iter == num_epochs * dataset_size

 # pylint: disable=redefined-outer-name
@@ -115,11 +116,74 @@ def test_minddtaset_generatordataset_exception_01(add_and_remove_cv_file):
     dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
                                   column_names=["data", "file_name", "label"], num_parallel_workers=1)
-    num_epoches = 3
-    iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
+    num_epochs = 3
+    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
     num_iter = 0
     with pytest.raises(RuntimeError) as error_info:
-        for _ in range(num_epoches):
+        for _ in range(num_epochs):
             for _ in iter_:
                 num_iter += 1
     assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)
+
+
+# pylint: disable=redefined-outer-name
+def test_minddtaset_generatordataset_exception_02(add_and_remove_file):
+    """
+    Feature: Test basic two level pipeline for mixed dataset.
+    Description: Invalid column name in MindDataset
+    Expectation: Throw expected exception.
+    """
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 1
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    file_paths = [file_name + "_cv" + str(i) for i in range(4)]
+    file_paths += [file_name + "_nlp" + str(i) for i in range(4)]
+
+    class MyIterable:
+        """ custom iteration """
+        def __init__(self, file_paths):
+            self._iter = None
+            self._index = 0
+            self._idx = 0
+            self._file_paths = file_paths
+
+        def __next__(self):
+            if self._index >= len(self._file_paths) * 10:
+                raise StopIteration
+            if self._iter:
+                try:
+                    item = next(self._iter)
+                    self._index += 1
+                except StopIteration:
+                    if self._idx >= len(self._file_paths):
+                        raise StopIteration
+                    self._iter = None
+                    return next(self)
+                return item
+            logger.info("load <<< {}.".format(self._file_paths[self._idx]))
+            self._iter = ds.MindDataset(self._file_paths[self._idx],
+                                        columns_list, num_parallel_workers=num_readers,
+                                        shuffle=None).create_tuple_iterator(num_epochs=1, output_numpy=True)
+            self._idx += 1
+            return next(self)
+
+        def __iter__(self):
+            self._index = 0
+            self._idx = 0
+            self._iter = None
+            return self
+
+        def __len__(self):
+            return len(self._file_paths) * 10
+
+    dataset = ds.GeneratorDataset(source=MyIterable(file_paths),
+                                  column_names=["data", "file_name", "label"], num_parallel_workers=1)
+    num_epochs = 1
+    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
+    num_iter = 0
+    with pytest.raises(RuntimeError) as error_info:
+        for _ in range(num_epochs):
+            for item in iter_:
+                print("item: ", item)
+                num_iter += 1
+    assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)
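The heart of the new test is MyIterable: it presents several MindRecord files to GeneratorDataset as a single source, opening each file lazily and rewinding cleanly when __iter__ is called again. The same pattern in isolation (a sketch; plain Python lists stand in for the per-file ds.MindDataset readers, and the class name is mine):

class ChainedSource:
    """Iterate several sub-sources in sequence, opening each lazily."""
    def __init__(self, sources):
        self._sources = sources
        self._iter = None      # iterator over the currently open source
        self._idx = 0          # index of the next source to open

    def __iter__(self):
        # Rewind so the pipeline can replay the source across epochs.
        self._idx = 0
        self._iter = None
        return self

    def __next__(self):
        while True:
            if self._iter is not None:
                try:
                    return next(self._iter)
                except StopIteration:
                    self._iter = None          # current source exhausted
            if self._idx >= len(self._sources):
                raise StopIteration            # no sources left
            self._iter = iter(self._sources[self._idx])
            self._idx += 1

# Usage: three "files" read back as one stream.
print(list(ChainedSource([[1, 2], [3], [4, 5]])))  # [1, 2, 3, 4, 5]

The sketch loops where the test recurses with return next(self); the behavior is the same, but the iterative form cannot hit the recursion limit when many consecutive sub-sources are empty.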

View File

@@ -16,7 +16,13 @@
 This module contains common utility functions for minddataset tests.
 """
 import os
+import re
+import string
+import collections
+
 import pytest
+import numpy as np
+
 from mindspore.mindrecord import FileWriter

 FILES_NUM = 4
@@ -54,6 +60,87 @@ def get_data(dir_name):
     return data_list


+def inputs(vectors, maxlen=50):
+    length = len(vectors)
+    if length > maxlen:
+        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
+    input_ = vectors + [0] * (maxlen - length)
+    mask = [1] * length + [0] * (maxlen - length)
+    segment = [0] * maxlen
+    return input_, mask, segment
+
+
+def convert_to_uni(text):
+    if isinstance(text, str):
+        return text
+    if isinstance(text, bytes):
+        return text.decode('utf-8', 'ignore')
+    raise Exception("The type %s does not convert!" % type(text))
+
+
+def load_vocab(vocab_file):
+    """load vocabulary to translate statement."""
+    vocab = collections.OrderedDict()
+    vocab.setdefault('blank', 2)
+    index = 0
+    with open(vocab_file) as reader:
+        while True:
+            tmp = reader.readline()
+            if not tmp:
+                break
+            token = convert_to_uni(tmp)
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def get_nlp_data(dir_name, vocab_file, num):
+    """
+    Return raw data of aclImdb dataset.
+
+    Args:
+        dir_name (str): String of aclImdb dataset's path.
+        vocab_file (str): String of dictionary's path.
+        num (int): Number of sample.
+
+    Returns:
+        List
+    """
+    if not os.path.isdir(dir_name):
+        raise IOError("Directory {} not exists".format(dir_name))
+    for root, _, files in os.walk(dir_name):
+        for index, file_name_extension in enumerate(files):
+            if index < num:
+                file_path = os.path.join(root, file_name_extension)
+                file_name, _ = file_name_extension.split('.', 1)
+                id_, rating = file_name.split('_', 1)
+                with open(file_path, 'r') as f:
+                    raw_content = f.read()
+                dictionary = load_vocab(vocab_file)
+                vectors = [dictionary.get('[CLS]')]
+                vectors += [dictionary.get(i) if i in dictionary
+                            else dictionary.get('[UNK]')
+                            for i in re.findall(r"[\w']+|[{}]"
+                                                .format(string.punctuation),
+                                                raw_content)]
+                vectors += [dictionary.get('[SEP]')]
+                input_, mask, segment = inputs(vectors)
+                input_ids = np.reshape(np.array(input_), [-1])
+                input_mask = np.reshape(np.array(mask), [1, -1])
+                segment_ids = np.reshape(np.array(segment), [2, -1])
+                data = {
+                    "label": 1,
+                    "id": id_,
+                    "rating": float(rating),
+                    "input_ids": input_ids,
+                    "input_mask": input_mask,
+                    "segment_ids": segment_ids
+                }
+                yield data
+
+
 @pytest.fixture
 def add_and_remove_cv_file():
     """add/remove cv file"""
@@ -86,3 +173,58 @@ def add_and_remove_cv_file():
         for x in paths:
             os.remove("{}".format(x))
             os.remove("{}.db".format(x))
+
+
+@pytest.fixture
+def add_and_remove_file():
+    """add/remove file"""
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    paths = ["{}{}".format(file_name + "_cv", str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    paths += ["{}{}".format(file_name + "_nlp", str(x).rjust(1, '0'))
+              for x in range(FILES_NUM)]
+    try:
+        for x in paths:
+            if os.path.exists("{}".format(x)):
+                os.remove("{}".format(x))
+            if os.path.exists("{}.db".format(x)):
+                os.remove("{}.db".format(x))
+        writer = FileWriter(file_name + "_nlp", FILES_NUM)
+        data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
+                                 "../data/mindrecord/testAclImdbData/vocab.txt",
+                                 10))
+        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
+                           "rating": {"type": "float32"},
+                           "input_ids": {"type": "int64",
+                                         "shape": [1, -1]},
+                           "input_mask": {"type": "int64",
+                                          "shape": [1, -1]},
+                           "segment_ids": {"type": "int64",
+                                           "shape": [1, -1]}
+                           }
+        writer.add_schema(nlp_schema_json, "nlp_schema")
+        writer.add_index(["id", "rating"])
+        writer.write_raw_data(data)
+        writer.commit()
+
+        writer = FileWriter(file_name + "_cv", FILES_NUM)
+        data = get_data(CV_DIR_NAME)
+        cv_schema_json = {"id": {"type": "int32"},
+                          "file_name": {"type": "string"},
+                          "label": {"type": "int32"},
+                          "data": {"type": "bytes"}}
+        writer.add_schema(cv_schema_json, "img_schema")
+        writer.add_index(["file_name", "label"])
+        writer.write_raw_data(data)
+        writer.commit()
+        yield "yield_data"
+    except Exception as error:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
+        raise error
+    else:
+        for x in paths:
+            os.remove("{}".format(x))
+            os.remove("{}.db".format(x))
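To make the new inputs helper above concrete, here is a self-contained check of its padding and masking (the function body is copied verbatim from the diff; the sample values are mine):

def inputs(vectors, maxlen=50):
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment

input_, mask, segment = inputs([7, 8, 9], maxlen=5)
print(input_)   # [7, 8, 9, 0, 0] -> token ids padded with 0
print(mask)     # [1, 1, 1, 0, 0] -> 1 marks real tokens
print(segment)  # [0, 0, 0, 0, 0] -> single-segment input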
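The add_and_remove_file fixture follows pytest's standard yield-fixture shape: everything before the yield builds the MindRecord files, and the except/else arms both delete them afterwards. A stripped-down sketch of that shape (the file name and content here are hypothetical, and pytest's built-in tmp_path fixture supplies the directory):

import os
import pytest

@pytest.fixture
def temp_record_file(tmp_path):
    # Setup: create the artifact the test needs (hypothetical content).
    path = tmp_path / "record.mindrecord"
    path.write_bytes(b"\x00" * 16)
    try:
        yield str(path)  # hand the path to the test body
    finally:
        # Teardown: runs whether the test passed or raised.
        if path.exists():
            os.remove(path)

def test_uses_record(temp_record_file):
    assert os.path.exists(temp_record_file)

A single try/finally covers both the pass and fail paths; the committed fixture spells the two paths out with except/else, which additionally re-raises any setup error after cleanup.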