Fix GIL issue
This commit is contained in:
parent
ff4974cca4
commit
f703f66f6c
|
@ -40,7 +40,6 @@ PythonRuntimeContext::~PythonRuntimeContext() {
|
||||||
MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
|
MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
|
||||||
}
|
}
|
||||||
if (tree_consumer_) {
|
if (tree_consumer_) {
|
||||||
py::gil_scoped_acquire gil_acquire;
|
|
||||||
tree_consumer_.reset();
|
tree_consumer_.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,8 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
from util_minddataset import add_and_remove_cv_file # pylint: disable=unused-import
|
from mindspore import log as logger
|
||||||
|
from util_minddataset import add_and_remove_cv_file, add_and_remove_file # pylint: disable=unused-import
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
@ -64,13 +65,13 @@ def test_minddtaset_generatordataset_01(add_and_remove_cv_file):
|
||||||
|
|
||||||
dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
|
dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
|
||||||
column_names=["data", "file_name", "label"], num_parallel_workers=1)
|
column_names=["data", "file_name", "label"], num_parallel_workers=1)
|
||||||
num_epoches = 3
|
num_epochs = 3
|
||||||
iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
|
iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for _ in range(num_epoches):
|
for _ in range(num_epochs):
|
||||||
for _ in iter_:
|
for _ in iter_:
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == num_epoches * dataset_size
|
assert num_iter == num_epochs * dataset_size
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
@ -115,11 +116,74 @@ def test_minddtaset_generatordataset_exception_01(add_and_remove_cv_file):
|
||||||
|
|
||||||
dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
|
dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
|
||||||
column_names=["data", "file_name", "label"], num_parallel_workers=1)
|
column_names=["data", "file_name", "label"], num_parallel_workers=1)
|
||||||
num_epoches = 3
|
num_epochs = 3
|
||||||
iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
|
iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
with pytest.raises(RuntimeError) as error_info:
|
with pytest.raises(RuntimeError) as error_info:
|
||||||
for _ in range(num_epoches):
|
for _ in range(num_epochs):
|
||||||
for _ in iter_:
|
for _ in iter_:
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)
|
assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
def test_minddtaset_generatordataset_exception_02(add_and_remove_file):
    """
    Feature: Test basic two level pipeline for mixed dataset.
    Description: Invalid column name in MindDataset
    Expectation: Throw expected exception.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 1
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]

    # The "_cv" files are listed first, followed by the "_nlp" files; every
    # file is opened with the same columns_list.
    file_paths = [file_name + "_cv" + str(i) for i in range(4)]
    file_paths += [file_name + "_nlp" + str(i) for i in range(4)]

    class MyIterable:
        """ custom iteration """

        def __init__(self, file_paths):
            self._iter = None   # iterator over the file currently being read
            self._index = 0     # number of rows returned so far
            self._idx = 0       # index of the next file to open
            self._file_paths = file_paths

        def __next__(self):
            # Loop instead of recursing: each pass either returns a row,
            # moves on to the next file, or stops the iteration.
            while True:
                if self._index >= len(self._file_paths) * 10:
                    raise StopIteration
                if self._iter:
                    try:
                        item = next(self._iter)
                    except StopIteration:
                        # Current file is exhausted.
                        if self._idx >= len(self._file_paths):
                            raise
                        self._iter = None
                        continue
                    self._index += 1
                    return item
                logger.info("load <<< {}.".format(self._file_paths[self._idx]))
                self._iter = ds.MindDataset(self._file_paths[self._idx],
                                            columns_list, num_parallel_workers=num_readers,
                                            shuffle=None).create_tuple_iterator(num_epochs=1, output_numpy=True)
                self._idx += 1

        def __iter__(self):
            self._index = 0
            self._idx = 0
            self._iter = None
            return self

        def __len__(self):
            return len(self._file_paths) * 10

    dataset = ds.GeneratorDataset(source=MyIterable(file_paths),
                                  column_names=["data", "file_name", "label"], num_parallel_workers=1)
    num_epochs = 1
    iter_ = dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
    with pytest.raises(RuntimeError) as error_info:
        for _ in range(num_epochs):
            for item in iter_:
                # Log instead of print so the rows do not flood test output.
                logger.info("item: {}".format(item))
    assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)
|
||||||
|
|
|
@ -16,7 +16,13 @@
|
||||||
This module contains common utility functions for minddataset tests.
|
This module contains common utility functions for minddataset tests.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
import collections
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from mindspore.mindrecord import FileWriter
|
from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
FILES_NUM = 4
|
FILES_NUM = 4
|
||||||
|
@ -54,6 +60,87 @@ def get_data(dir_name):
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
def inputs(vectors, maxlen=50):
    """
    Truncate or zero-pad a token-id list to exactly `maxlen` entries.

    Args:
        vectors (list): Token ids.
        maxlen (int): Target length. Default: 50.

    Returns:
        tuple, (input ids, mask, segment ids). The mask holds 1 for real
        tokens and 0 for padding; the segment ids are all zeros.
    """
    count = len(vectors)
    if count > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    padding = [0] * (maxlen - count)
    return vectors + padding, [1] * count + padding, [0] * maxlen
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_uni(text):
    """
    Convert `text` to a str.

    Args:
        text (Union[str, bytes]): Value to convert. Bytes are decoded as
            UTF-8 and undecodable bytes are dropped.

    Returns:
        str, the text itself or its decoded form.

    Raises:
        TypeError: If `text` is neither str nor bytes.
    """
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    # TypeError is a subclass of Exception, so callers that catch Exception
    # keep working while the failure is reported more precisely.
    raise TypeError("The type %s does not convert!" % type(text))
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocab(vocab_file):
    """
    Load vocabulary to translate statement.

    Each line of `vocab_file` is one token; after stripping surrounding
    whitespace, the token maps to its zero-based line number. The mapping is
    pre-seeded with 'blank' -> 2 before the file is read.

    Args:
        vocab_file (str): Path of the vocabulary file, read as UTF-8.

    Returns:
        collections.OrderedDict, token to index mapping.
    """
    vocab = collections.OrderedDict()
    vocab.setdefault('blank', 2)
    # Explicit encoding keeps the result independent of the platform locale.
    # Text-mode reads always yield str, so no bytes-to-str conversion is needed.
    with open(vocab_file, encoding='utf-8') as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab
|
||||||
|
|
||||||
|
|
||||||
|
def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of aclImdb dataset.

    File names are expected to look like "<id>_<rating>.<extension>". At
    most `num` files are read from each directory visited by os.walk.

    Args:
        dir_name (str): String of aclImdb dataset's path.
        vocab_file (str): String of dictionary's path.
        num (int): Number of sample.

    Yields:
        dict, one sample with the keys "label", "id", "rating", "input_ids",
        "input_mask" and "segment_ids".

    Raises:
        IOError: If `dir_name` is not a directory (raised on first iteration).
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} not exists".format(dir_name))

    # Loop-invariant work is done once: the token pattern is compiled here and
    # the vocabulary is loaded lazily when the first file is processed.
    token_pattern = re.compile(r"[\w']+|[{}]".format(string.punctuation))
    dictionary = None
    unk_id = None

    for root, _, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index >= num:
                # Nothing past the first `num` files is used.
                break
            file_path = os.path.join(root, file_name_extension)
            file_name, _extension = file_name_extension.split('.', 1)
            id_, rating = file_name.split('_', 1)
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_content = f.read()

            if dictionary is None:
                dictionary = load_vocab(vocab_file)
                unk_id = dictionary.get('[UNK]')

            # [CLS] + token ids (unknown tokens map to [UNK]) + [SEP]
            vectors = [dictionary.get('[CLS]')]
            vectors += [dictionary.get(token, unk_id)
                        for token in token_pattern.findall(raw_content)]
            vectors += [dictionary.get('[SEP]')]
            input_, mask, segment = inputs(vectors)
            input_ids = np.reshape(np.array(input_), [-1])
            input_mask = np.reshape(np.array(mask), [1, -1])
            segment_ids = np.reshape(np.array(segment), [2, -1])
            data = {
                "label": 1,
                "id": id_,
                "rating": float(rating),
                "input_ids": input_ids,
                "input_mask": input_mask,
                "segment_ids": segment_ids
            }
            yield data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def add_and_remove_cv_file():
|
def add_and_remove_cv_file():
|
||||||
"""add/remove cv file"""
|
"""add/remove cv file"""
|
||||||
|
@ -86,3 +173,58 @@ def add_and_remove_cv_file():
|
||||||
for x in paths:
|
for x in paths:
|
||||||
os.remove("{}".format(x))
|
os.remove("{}".format(x))
|
||||||
os.remove("{}.db".format(x))
|
os.remove("{}.db".format(x))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def add_and_remove_file():
    """
    add/remove file

    Writes FILES_NUM "<test name>_nlp<i>" and FILES_NUM "<test name>_cv<i>"
    MindRecord files before the test and removes them, together with their
    ".db" companions, afterwards.
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}_cv{}".format(file_name, x) for x in range(FILES_NUM)]
    paths += ["{}_nlp{}".format(file_name, x) for x in range(FILES_NUM)]

    def remove_files():
        # Only remove what exists, so cleanup never raises FileNotFoundError
        # and never hides the error that interrupted setup or the test.
        for path in paths:
            for target in (path, "{}.db".format(path)):
                if os.path.exists(target):
                    os.remove(target)

    try:
        # Start from a clean state in case an earlier run left files behind.
        remove_files()

        writer = FileWriter(file_name + "_nlp", FILES_NUM)
        data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                                 "../data/mindrecord/testAclImdbData/vocab.txt",
                                 10))
        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                           "rating": {"type": "float32"},
                           "input_ids": {"type": "int64",
                                         "shape": [1, -1]},
                           "input_mask": {"type": "int64",
                                          "shape": [1, -1]},
                           "segment_ids": {"type": "int64",
                                           "shape": [1, -1]}
                           }
        writer.add_schema(nlp_schema_json, "nlp_schema")
        writer.add_index(["id", "rating"])
        writer.write_raw_data(data)
        writer.commit()

        writer = FileWriter(file_name + "_cv", FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()

        yield "yield_data"
    finally:
        # Runs on success and on failure; any exception keeps propagating
        # with its original traceback.
        remove_files()
|
||||||
|
|
Loading…
Reference in New Issue