forked from mindspore-Ecosystem/mindspore
fix: may max openfiles in mindrecord and onehot error in py_transform
and cache_admin rpath
This commit is contained in:
parent
6b0e8fef6b
commit
da78a17730
|
@ -41,7 +41,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor")
|
||||
else()
|
||||
# add python lib dir to rpath
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/..:$ORIGIN/../lib:$ORIGIN/../../../..")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/..:$ORIGIN/../lib")
|
||||
endif()
|
||||
|
||||
if(ENABLE_CACHE)
|
||||
|
|
|
@ -193,7 +193,12 @@ Status ShardReader::Open() {
|
|||
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file: " + file);
|
||||
if (!fs->good()) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
!fs->fail(),
|
||||
"Maybe reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file: " + file);
|
||||
}
|
||||
MS_LOG(INFO) << "Open shard file successfully.";
|
||||
file_streams_.push_back(fs);
|
||||
}
|
||||
|
@ -220,7 +225,12 @@ Status ShardReader::Open(int n_consumer) {
|
|||
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file: " + file);
|
||||
if (!fs->good()) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
!fs->fail(),
|
||||
"Maybe reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file: " + file);
|
||||
}
|
||||
file_streams_random_[j].push_back(fs);
|
||||
}
|
||||
MS_LOG(INFO) << "Open shard file successfully.";
|
||||
|
|
|
@ -31,4 +31,9 @@ def main():
|
|||
cache_server = os.path.join(cache_admin_dir, "cache_server")
|
||||
os.chmod(cache_admin, stat.S_IRWXU)
|
||||
os.chmod(cache_server, stat.S_IRWXU)
|
||||
sys.exit(subprocess.call([cache_admin] + sys.argv[1:], shell=False))
|
||||
|
||||
# set LD_LIBRARY_PATH for libpython*.so
|
||||
python_lib_dir = os.path.join(os.path.dirname(mindspore.__file__), "../../..")
|
||||
os.environ['LD_LIBRARY_PATH'] = python_lib_dir + ":" + os.environ.get('LD_LIBRARY_PATH')
|
||||
|
||||
sys.exit(subprocess.call([cache_admin] + sys.argv[1:], shell=False, env=os.environ))
|
||||
|
|
|
@ -66,17 +66,43 @@ def one_hot_encoding(label, num_classes, epsilon):
|
|||
|
||||
Returns:
|
||||
img (numpy.ndarray), label after being one hot encoded and done label smoothed.
|
||||
|
||||
Examples:
|
||||
>>> # assume num_classes = 5
|
||||
>>> # 1) input np.array(3) output [0, 0, 0, 1, 0]
|
||||
>>> # 2) input np.array([4, 2, 0]) output [[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]
|
||||
>>> # 3) input np.array([[4], [2], [0]]) output [[[0, 0, 0, 0, 1]], [[0, 0, 1, 0, 0][, [[1, 0, 0, 0, 0]]]
|
||||
"""
|
||||
if label > num_classes:
|
||||
raise ValueError('the num_classes is smaller than the category number.')
|
||||
if isinstance(label, np.ndarray): # the numpy should be () or (1, ) or shape: (n, 1)
|
||||
if label.dtype not in [np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64]:
|
||||
raise ValueError('the input numpy type should be int, but the input is: ' + str(label.dtype))
|
||||
|
||||
num_elements = label.size
|
||||
one_hot_label = np.zeros((num_elements, num_classes), dtype=int)
|
||||
if label.ndim == 0:
|
||||
if label >= num_classes:
|
||||
raise ValueError('the num_classes is smaller than the category number.')
|
||||
|
||||
if isinstance(label, list) is False:
|
||||
label = [label]
|
||||
for index in range(num_elements):
|
||||
one_hot_label[index, label[index]] = 1
|
||||
one_hot_label = np.zeros((num_classes), dtype=int)
|
||||
one_hot_label[label] = 1
|
||||
else:
|
||||
label_flatten = label.flatten()
|
||||
for item in label_flatten:
|
||||
if item >= num_classes:
|
||||
raise ValueError('the num_classes:' + str(num_classes) +
|
||||
' is smaller than the category number:' + str(item))
|
||||
|
||||
num_elements = label_flatten.size
|
||||
one_hot_label = np.zeros((num_elements, num_classes), dtype=int)
|
||||
for index in range(num_elements):
|
||||
one_hot_label[index][label_flatten[index]] = 1
|
||||
|
||||
new_shape = []
|
||||
for dim in label.shape:
|
||||
new_shape.append(dim)
|
||||
new_shape.append(num_classes)
|
||||
one_hot_label = one_hot_label.reshape(new_shape)
|
||||
else:
|
||||
raise ValueError('the input is invalid, it should be numpy.ndarray.')
|
||||
|
||||
return (1 - epsilon) * one_hot_label + epsilon / num_classes
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ def test_compose():
|
|||
assert test_config([[1, 0]], [ops.Duplicate(), ops.Concatenate(), ops.Duplicate(), ops.Concatenate()]) == [
|
||||
[1, 0] * 4]
|
||||
# test one python transform followed by a C transform. type after oneHot is float (mixed use-case)
|
||||
assert test_config([1, 0], [py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)]) == [[[0, 1]], [[1, 0]]]
|
||||
assert test_config([1, 0], [py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)]) == [[0, 1], [1, 0]]
|
||||
# test exceptions. compose, randomApply randomChoice use the same validator
|
||||
assert "op_list[0] is neither a c_transform op" in test_config([1, 0], [1, ops.TypeCast(mstype.int32)])
|
||||
# test empty op list
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_compose():
|
|||
# Test one Python transform followed by a C++ transform. Type after OneHot is a float (mixed use-case)
|
||||
assert test_config([1, 0],
|
||||
c_transforms.Compose([py_transforms.OneHotOp(2), c_transforms.TypeCast(mstype.int32)])) \
|
||||
== [[[0, 1]], [[1, 0]]]
|
||||
== [[0, 1], [1, 0]]
|
||||
|
||||
# Test exceptions.
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
|
@ -71,20 +71,20 @@ def test_compose():
|
|||
assert "op_list can not be empty." in str(error_info.value)
|
||||
|
||||
# Test Python compose op
|
||||
assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[[0, 1]], [[1, 0]]]
|
||||
assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[[0, 2]],
|
||||
[[2, 0]]]
|
||||
assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[0, 1], [1, 0]]
|
||||
assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[0, 2],
|
||||
[2, 0]]
|
||||
|
||||
# Test nested Python compose op
|
||||
assert test_config([1, 0],
|
||||
py_transforms.Compose([py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)])) \
|
||||
== [[[0, 2]], [[2, 0]]]
|
||||
== [[0, 2], [2, 0]]
|
||||
|
||||
# Test passing a list of Python ops without Compose wrapper
|
||||
assert test_config([1, 0],
|
||||
[py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)]) \
|
||||
== [[[0, 2]], [[2, 0]]]
|
||||
assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[[0, 2]], [[2, 0]]]
|
||||
== [[0, 2], [2, 0]]
|
||||
assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[0, 2], [2, 0]]
|
||||
|
||||
# Test a non callable function
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
|
@ -149,14 +149,14 @@ def test_c_py_compose_transforms_module():
|
|||
arr = [1, 0]
|
||||
assert test_config(arr, ["cols"], ["cols"],
|
||||
[py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]) == \
|
||||
[[[False, True]],
|
||||
[[True, False]]]
|
||||
[[False, True],
|
||||
[True, False]]
|
||||
assert test_config(arr, ["cols"], ["cols"],
|
||||
[py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1)]) \
|
||||
== [[[1, 1]], [[1, 1]]]
|
||||
== [[1, 1], [1, 1]]
|
||||
assert test_config(arr, ["cols"], ["cols"],
|
||||
[py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1), (lambda x: x + x)]) \
|
||||
== [[[2, 2]], [[2, 2]]]
|
||||
== [[2, 2], [2, 2]]
|
||||
assert test_config([[1, 3]], ["cols"], ["cols"],
|
||||
[c_transforms.PadEnd([3], -1), (lambda x: x + x)]) \
|
||||
== [[2, 6, -2]]
|
||||
|
@ -248,8 +248,7 @@ def test_py_transforms_with_c_vision():
|
|||
|
||||
with pytest.raises(RuntimeError) as error_info:
|
||||
test_config([py_transforms.OneHotOp(20, 0.1)])
|
||||
assert "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" in str(
|
||||
error_info.value)
|
||||
assert "is smaller than the category number" in str(error_info.value)
|
||||
|
||||
|
||||
def test_py_vision_with_c_transforms():
|
||||
|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
|||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
import mindspore.dataset.transforms.py_transforms as py_trans
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
from util import dataset_equal_with_function
|
||||
|
@ -98,7 +99,119 @@ def test_one_hot_post_aug():
|
|||
|
||||
assert num_iter == 1
|
||||
|
||||
def test_one_hot_success():
|
||||
# success
|
||||
class GetDatasetGenerator:
|
||||
def __init__(self):
|
||||
np.random.seed(58)
|
||||
self.__data = np.random.sample((5, 2))
|
||||
self.__label = []
|
||||
for index in range(5):
|
||||
self.__label.append(np.array(index))
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.__data[index], self.__label[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__data)
|
||||
|
||||
dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)
|
||||
|
||||
one_hot_encode = py_trans.OneHotOp(10)
|
||||
trans = py_trans.Compose([one_hot_encode])
|
||||
dataset = dataset.map(operations=trans, input_columns=["label"])
|
||||
|
||||
for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||
assert item["label"][index] == 1.0
|
||||
|
||||
def test_one_hot_success2():
|
||||
# success
|
||||
class GetDatasetGenerator:
|
||||
def __init__(self):
|
||||
np.random.seed(58)
|
||||
self.__data = np.random.sample((5, 2))
|
||||
self.__label = []
|
||||
for index in range(5):
|
||||
self.__label.append(np.array([index]))
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.__data[index], self.__label[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__data)
|
||||
|
||||
dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)
|
||||
|
||||
one_hot_encode = py_trans.OneHotOp(10)
|
||||
trans = py_trans.Compose([one_hot_encode])
|
||||
dataset = dataset.map(operations=trans, input_columns=["label"])
|
||||
|
||||
for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||
logger.info(item)
|
||||
assert item["label"][0][index] == 1.0
|
||||
|
||||
def test_one_hot_success3():
|
||||
# success
|
||||
class GetDatasetGenerator:
|
||||
def __init__(self):
|
||||
np.random.seed(58)
|
||||
self.__data = np.random.sample((5, 2))
|
||||
self.__label = []
|
||||
for _ in range(5):
|
||||
value = np.ones([10, 1], dtype=np.int32)
|
||||
for i in range(10):
|
||||
value[i][0] = i
|
||||
self.__label.append(value)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.__data[index], self.__label[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__data)
|
||||
|
||||
dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)
|
||||
|
||||
one_hot_encode = py_trans.OneHotOp(10)
|
||||
trans = py_trans.Compose([one_hot_encode])
|
||||
dataset = dataset.map(operations=trans, input_columns=["label"])
|
||||
|
||||
for item in dataset.create_dict_iterator(output_numpy=True):
|
||||
logger.info(item)
|
||||
for i in range(10):
|
||||
assert item["label"][i][0][i] == 1.0
|
||||
|
||||
def test_one_hot_type_error():
|
||||
# type error
|
||||
class GetDatasetGenerator:
|
||||
def __init__(self):
|
||||
np.random.seed(58)
|
||||
self.__data = np.random.sample((5, 2))
|
||||
self.__label = []
|
||||
for index in range(5):
|
||||
self.__label.append(np.array(float(index)))
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.__data[index], self.__label[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__data)
|
||||
|
||||
dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)
|
||||
|
||||
one_hot_encode = py_trans.OneHotOp(10)
|
||||
trans = py_trans.Compose([one_hot_encode])
|
||||
dataset = dataset.map(operations=trans, input_columns=["label"])
|
||||
|
||||
try:
|
||||
for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||
assert item["label"][index] == 1.0
|
||||
except RuntimeError as e:
|
||||
assert "the input numpy type should be int" in str(e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_one_hot()
|
||||
test_one_hot_post_aug()
|
||||
test_one_hot_success()
|
||||
test_one_hot_success2()
|
||||
test_one_hot_success3()
|
||||
test_one_hot_type_error()
|
||||
|
|
Loading…
Reference in New Issue