Block caching after random pyfunc

This commit is contained in:
Lixia Chen 2021-02-08 13:25:21 -05:00
parent ed7fef5d5e
commit 0667818d9a
9 changed files with 214 additions and 4 deletions

View File

@ -30,6 +30,9 @@
#endif
#include "minddata/dataset/kernels/ir/validators.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/kernels/py_func_op.h"
#endif
namespace mindspore {
namespace dataset {
@ -78,7 +81,12 @@ Status OneHotOperation::ValidateParams() {
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
// PreBuiltOperation
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {}
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {
#ifdef ENABLE_PYTHON
auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(tensor_op);
if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) random_op_ = true;
#endif
}
Status PreBuiltOperation::ValidateParams() { return Status::OK(); }

View File

@ -129,5 +129,12 @@ Status PyFuncOp::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
bool PyFuncOp::IsRandom() {
bool random = true;
if (py::hasattr(py_func_ptr_, "random") && py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false)
random = false;
return random;
}
} // namespace dataset
} // namespace mindspore

View File

@ -51,6 +51,10 @@ class PyFuncOp : public TensorOp {
std::string Name() const override { return kPyFuncOp; }
Status to_json(nlohmann::json *out_json) override;
/// \brief Check whether this pyfunc op is deterministic
/// \return True if this pyfunc op is random
bool IsRandom();
private:
py::function py_func_ptr_;
DataType::Type output_type_;

View File

@ -552,6 +552,7 @@ class PythonTokenizer:
@check_python_tokenizer
def __init__(self, tokenizer):
self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')
self.random = False
def __call__(self, in_array):
in_array = to_str(in_array)

View File

@ -21,6 +21,11 @@ from .validators import check_one_hot_op, check_compose_list, check_random_apply
from . import py_transforms_util as util
def not_random(function):
function.random = False
return function
class OneHotOp:
"""
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
@ -42,6 +47,7 @@ class OneHotOp:
def __init__(self, num_classes, smoothing_rate=0.0):
self.num_classes = num_classes
self.smoothing_rate = smoothing_rate
self.random = False
def __call__(self, label):
"""
@ -114,6 +120,8 @@ class Compose:
@check_compose_list
def __init__(self, transforms):
self.transforms = transforms
if all(hasattr(transform, "random") and not transform.random for transform in self.transforms):
self.random = False
@check_compose_call
def __call__(self, *args):

View File

@ -45,6 +45,11 @@ DE_PY_BORDER_TYPE = {Border.CONSTANT: 'constant',
Border.SYMMETRIC: 'symmetric'}
def not_random(function):
function.random = False
return function
class ToTensor:
"""
Convert the input NumPy image array or PIL image of shape (H, W, C) to a NumPy ndarray of shape (C, H, W).
@ -70,6 +75,7 @@ class ToTensor:
def __init__(self, output_type=np.float32):
self.output_type = output_type
self.random = False
def __call__(self, img):
"""
@ -105,6 +111,7 @@ class ToType:
def __init__(self, output_type):
self.output_type = output_type
self.random = False
def __call__(self, img):
"""
@ -132,6 +139,9 @@ class HWC2CHW:
... input_columns="image")
"""
def __init__(self):
self.random = False
def __call__(self, img):
"""
Call method.
@ -160,6 +170,9 @@ class ToPIL:
... input_columns="image")
"""
def __init__(self):
self.random = False
def __call__(self, img):
"""
Call method.
@ -187,6 +200,9 @@ class Decode:
... input_columns="image")
"""
def __init__(self):
self.random = False
def __call__(self, img):
"""
Call method.
@ -227,6 +243,7 @@ class Normalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std
self.random = False
def __call__(self, img):
"""
@ -271,6 +288,7 @@ class NormalizePad:
self.mean = mean
self.std = std
self.dtype = dtype
self.random = False
def __call__(self, img):
"""
@ -456,6 +474,7 @@ class Resize:
def __init__(self, size, interpolation=Inter.BILINEAR):
self.size = size
self.interpolation = DE_PY_INTER_MODE[interpolation]
self.random = False
def __call__(self, img):
"""
@ -550,6 +569,7 @@ class CenterCrop:
@check_crop
def __init__(self, size):
self.size = size
self.random = False
def __call__(self, img):
"""
@ -700,6 +720,7 @@ class FiveCrop:
@check_crop
def __init__(self, size):
self.size = size
self.random = False
def __call__(self, img):
"""
@ -744,6 +765,7 @@ class TenCrop:
size = (size, size)
self.size = size
self.use_vertical_flip = use_vertical_flip
self.random = False
def __call__(self, img):
"""
@ -781,6 +803,7 @@ class Grayscale:
@check_num_channels
def __init__(self, num_output_channels=1):
self.num_output_channels = num_output_channels
self.random = False
def __call__(self, img):
"""
@ -884,6 +907,7 @@ class Pad:
self.padding = padding
self.fill_value = fill_value
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
self.random = False
def __call__(self, img):
"""
@ -1030,6 +1054,7 @@ class Cutout:
def __init__(self, length, num_patches=1):
self.length = length
self.num_patches = num_patches
self.random = False
def __call__(self, np_img):
"""
@ -1087,6 +1112,7 @@ class LinearTransformation:
def __init__(self, transformation_matrix, mean_vector):
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
self.random = False
def __call__(self, np_img):
"""
@ -1229,6 +1255,7 @@ class MixUp:
self.batch_size = batch_size
self.alpha = alpha
self.is_single = is_single
self.random = False
def __call__(self, image, label):
"""
@ -1268,6 +1295,7 @@ class RgbToHsv:
def __init__(self, is_hwc=False):
self.is_hwc = is_hwc
self.random = False
def __call__(self, rgb_imgs):
"""
@ -1304,6 +1332,7 @@ class HsvToRgb:
def __init__(self, is_hwc=False):
self.is_hwc = is_hwc
self.random = False
def __call__(self, hsv_imgs):
"""
@ -1414,6 +1443,7 @@ class AutoContrast:
def __init__(self, cutoff=0.0, ignore=None):
self.cutoff = cutoff
self.ignore = ignore
self.random = False
def __call__(self, img):
"""
@ -1443,6 +1473,9 @@ class Invert:
... input_columns="image")
"""
def __init__(self):
self.random = False
def __call__(self, img):
"""
Call method.
@ -1472,6 +1505,9 @@ class Equalize:
"""
def __init__(self):
self.random = False
def __call__(self, img):
"""
Call method.
@ -1516,6 +1552,7 @@ class UniformAugment:
def __init__(self, transforms, num_ops=2):
self.transforms = transforms
self.num_ops = num_ops
self.random = False
def __call__(self, img):
"""

View File

@ -318,6 +318,9 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
HandleRcExit $? 0 0
for i in $(seq 1 3)
do
test_name="test_cache_nomap_multiple_cache${i}"

View File

@ -20,6 +20,7 @@ import pytest
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import save_and_check_md5
@ -481,7 +482,7 @@ def test_cache_map_failure7():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
data = ds.GeneratorDataset(generator_1d, ["data"])
data = data.map((lambda x: x), ["data"], cache=some_cache)
data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache)
data = data.repeat(4)
with pytest.raises(RuntimeError) as e:

View File

@ -17,11 +17,13 @@ Testing cache operator with non-mappable datasets
"""
import os
import itertools
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@ -41,6 +43,9 @@ CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"
PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"
GENERATE_GOLDEN = False
@ -1633,7 +1638,7 @@ def test_cache_nomap_clue2():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
ds1 = ds1.map(py_vision.not_random(lambda x: x), ["label"], cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@ -1710,7 +1715,7 @@ def test_cache_nomap_csv2():
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)
ds1 = ds1.map(py_vision.not_random(lambda x: x), ["col1"], cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@ -2124,6 +2129,139 @@ def test_cache_nomap_failure5():
logger.info('test_cache_nomap_failure5 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_lambda():
"""
Test cache after map op with a python lambda function.
Only allowed if the lambda function is wrapped by 'pyvision.not_random', otherwise an error will be raised.
Cache
|
Map(lambda function1, lambda function2)
|
TFRecord
"""
logger.info("Test cache nomap pyfunc lambda")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 12 records in it
data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
transforms = [py_vision.not_random(lambda x: x + x), py_vision.not_random(lambda x: x - 1)]
data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)
num_iter = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 12
other_cache = ds.DatasetCache(session_id=session_id, size=0)
ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)
with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in ds2.create_dict_iterator():
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_builtin():
"""
Test cache after map op with a python builtin PyFunc.
An error will be raised if the builtin pyfunc containing random operation.
Cache
|
Map([builtin pyfunc1, builtin pyfunc2])
|
TFRecord
"""
logger.info("Test cache nomap pyfunc builtin")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
ds1 = ds1.map(operations=[py_vision.Decode(), py_vision.ToTensor()], input_columns=["image"], cache=some_cache)
num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 3
other_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
ds2 = ds2.map(operations=[py_vision.Decode(), py_vision.RandomCrop(224), py_vision.ToTensor()],
input_columns=["image"], cache=other_cache)
with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in ds2.create_dict_iterator():
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_function():
"""
Test cache after map op with a python customized function.
Only allowed if the function is decorated with 'py_vision.not_random', otherwise an error will be raised.
Cache
|
Map([function1, function2])
|
TFRecord
"""
@py_vision.not_random
def not_random_func(x):
return np.ones(x.shape, dtype=x.dtype)
def normal_func(x):
return np.ones(x.shape, dtype=x.dtype)
logger.info("Test cache nomap pyfunc function")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)
num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 3
other_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)
with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in ds2.create_dict_iterator():
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
logger.info("test_cache_nomap_pyfunc_function Ended.\n")
if __name__ == '__main__':
# This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
# since cache server is required to be brought up first
@ -2180,3 +2318,6 @@ if __name__ == '__main__':
test_cache_nomap_failure3()
test_cache_nomap_failure4()
test_cache_nomap_failure5()
test_cache_nomap_pyfunc_lambda()
test_cache_nomap_pyfunc_builtin()
test_cache_nomap_pyfunc_function()