add generatordataset support split
This commit is contained in:
parent
6ea2aa4e73
commit
6017b050ac
|
@ -875,6 +875,8 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
|
|||
std::shared_ptr<CVTensor> output_cv;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_image, &output_cv));
|
||||
// pad the dimension if shape information is only 2 dimensional, this is grayscale
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input_cv->Rank() == 3,
|
||||
"Pad error: invalid image shape, only support 3 channels image.");
|
||||
int num_channels = input_cv->shape()[2];
|
||||
if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2);
|
||||
*output = std::static_pointer_cast<Tensor>(output_cv);
|
||||
|
|
|
@ -2016,6 +2016,7 @@ class MapDataset(DatasetOp):
|
|||
new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
|
||||
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
||||
new_op.parent = copy.deepcopy(self.parent, memodict)
|
||||
new_op.ms_role = copy.deepcopy(self.ms_role, memodict)
|
||||
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
|
||||
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
|
||||
new_op.cache = copy.deepcopy(self.cache, memodict)
|
||||
|
@ -3283,6 +3284,7 @@ class GeneratorDataset(MappableDataset):
|
|||
memodict[id(self)] = new_op
|
||||
new_op.children = copy.deepcopy(self.children, memodict)
|
||||
new_op.parent = copy.deepcopy(self.parent, memodict)
|
||||
new_op.ms_role = copy.deepcopy(self.ms_role, memodict)
|
||||
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
||||
new_op.column_types = copy.deepcopy(self.column_types, memodict)
|
||||
new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
||||
|
|
|
@ -148,8 +148,20 @@ def test_pad_md5():
|
|||
filename2 = "pad_01_py_result.npz"
|
||||
save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
def test_pad_exception():
|
||||
try:
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
pad_op = c_vision.Pad(150)
|
||||
data1 = data1.map(input_columns=["image"], operations=pad_op)
|
||||
for _ in data1.create_dict_iterator():
|
||||
pass
|
||||
assert False
|
||||
except RuntimeError as e:
|
||||
assert "Pad error: invalid image shape, only support 3 channels image" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pad_op()
|
||||
test_pad_grayscale()
|
||||
test_pad_md5()
|
||||
test_pad_exception()
|
||||
|
|
Loading…
Reference in New Issue