Add split support to GeneratorDataset

This commit is contained in:
xiefangqi 2020-07-30 14:40:42 +08:00
parent 6ea2aa4e73
commit 6017b050ac
3 changed files with 16 additions and 0 deletions

View File

@ -875,6 +875,8 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_image, &output_cv));
// Expand the channel dimension when the padded output is only 2-dimensional (single-channel, i.e. grayscale, input)
CHECK_FAIL_RETURN_UNEXPECTED(input_cv->Rank() == 3,
"Pad error: invalid image shape, only support 3 channels image.");
int num_channels = input_cv->shape()[2];
if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2);
*output = std::static_pointer_cast<Tensor>(output_cv);

View File

@ -2016,6 +2016,7 @@ class MapDataset(DatasetOp):
new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.ms_role = copy.deepcopy(self.ms_role, memodict)
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.cache = copy.deepcopy(self.cache, memodict)
@ -3283,6 +3284,7 @@ class GeneratorDataset(MappableDataset):
memodict[id(self)] = new_op
new_op.children = copy.deepcopy(self.children, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.ms_role = copy.deepcopy(self.ms_role, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.column_types = copy.deepcopy(self.column_types, memodict)
new_op.column_names = copy.deepcopy(self.column_names, memodict)

View File

@ -148,8 +148,20 @@ def test_pad_md5():
filename2 = "pad_01_py_result.npz"
save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN)
def test_pad_exception():
    """
    Check that c_vision.Pad reports an invalid image shape as a RuntimeError.

    The pipeline applies Pad directly to the "image" column (no other
    operation is mapped before it) and then iterates the dataset. The test
    passes only if iteration raises a RuntimeError whose text contains the
    Pad shape-validation message.
    """
    try:
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        pad_op = c_vision.Pad(150)
        data1 = data1.map(input_columns=["image"], operations=pad_op)
        # Iterating is what actually executes the Pad operation on the data.
        for _ in data1.create_dict_iterator():
            pass
    except RuntimeError as e:
        assert "Pad error: invalid image shape, only support 3 channels image" in str(e)
    else:
        # No exception was raised at all: fail with an explicit reason instead
        # of an anonymous `assert False` hidden inside the try body.
        assert False, "Pad did not raise RuntimeError for an invalid image shape"
if __name__ == "__main__":
    # Script entry point: run each pad test once, in the listed order.
    for pad_test in (test_pad_op, test_pad_grayscale, test_pad_md5, test_pad_exception):
        pad_test()