[MD] Transform Unification Feature - ToTensor op: Remove Python implementation, regenerate golden files

parent 797c21ec6c
commit 1136ee3ffd
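Note: this commit removes ToTensor's Python execute_py path, renames the op's parameter from data_type to output_type throughout the C++ IR and bindings, and registers ToTensorOperation::from_json so serialized pipelines that contain ToTensor can be deserialized. A minimal usage sketch (illustrative only, not part of the diff; the dataset path and the unified mindspore.dataset.vision.transforms API are taken from the tests changed below):

    import mindspore.dataset as ds
    import mindspore.dataset.vision.transforms as vision

    # ToTensor now always parses into the C++ ToTensorOperation; there is no
    # Python execute_py fallback anymore.
    data = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False)
    data = data.map(operations=[vision.Decode(True), vision.ToTensor()], input_columns=["image"])

    # Because ToTensorOperation::from_json is now registered with Serdes, such a
    # pipeline can be serialized to JSON and deserialized back.
    ds.serialize(data, "totensor_pipeline.json")
    data2 = ds.deserialize(json_filepath="totensor_pipeline.json")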
@@ -719,8 +719,8 @@ PYBIND_REGISTER(ToTensorOperation, 1, ([](const py::module *m) {
   (void)
     py::class_<vision::ToTensorOperation, TensorOperation, std::shared_ptr<vision::ToTensorOperation>>(
       *m, "ToTensorOperation")
-      .def(py::init([](const std::string &data_type) {
-        auto totensor = std::make_shared<vision::ToTensorOperation>(data_type);
+      .def(py::init([](const std::string &output_type) {
+        auto totensor = std::make_shared<vision::ToTensorOperation>(output_type);
         THROW_IF_ERROR(totensor->ValidateParams());
         return totensor;
       }));
@@ -260,7 +260,9 @@ Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shar
   for (nlohmann::json item : json_obj) {
     if (item.find("python_module") != item.end()) {
      if (Py_IsInitialized() != 0) {
-        RETURN_IF_NOT_OK(PyFuncOp::from_json(item, result));
+        std::vector<std::shared_ptr<TensorOperation>> tmp_res;
+        RETURN_IF_NOT_OK(PyFuncOp::from_json(item, &tmp_res));
+        output.insert(output.end(), tmp_res.begin(), tmp_res.end());
      } else {
        LOG_AND_RETURN_STATUS_SYNTAX_ERROR(
          "Python module is not initialized or Pyfunction is not supported on this platform.");
@@ -275,9 +277,9 @@ Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shar
        "Invalid data, unsupported operation: " + op_name);
      RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
      output.push_back(operation);
-      *result = output;
    }
  }
+  *result = output;
  return Status::OK();
}
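The hunk above makes ConstructTensorOps append each deserialized Python function to the running output list instead of overwriting the result inside the loop, and it writes *result once after all ops are built. A round-trip sketch of the case this affects (illustrative; the dataset paths are taken from the tests in this commit):

    import numpy as np
    import mindspore.dataset as ds

    data = ds.TFRecordDataset(["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"],
                              "../data/dataset/test_tf_file_3_images/datasetSchema.json",
                              columns_list=["image", "label"], shuffle=False)
    # A Python callable is stored as a "python_module" entry in the JSON and is
    # rebuilt through PyFuncOp::from_json when the pipeline is deserialized.
    data = data.map(operations=(lambda x: np.array(x)), input_columns=["image"])
    ds.serialize(data, "pyfunc_dataset_pipeline.json")
    data2 = ds.deserialize(json_filepath="pyfunc_dataset_pipeline.json")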
@@ -344,6 +346,7 @@ Serdes::InitializeFuncPtr() {
     &(vision::SoftDvppDecodeRandomCropResizeJpegOperation::from_json);
   ops_ptr[vision::kSoftDvppDecodeResizeJpegOperation] = &(vision::SoftDvppDecodeResizeJpegOperation::from_json);
   ops_ptr[vision::kSwapRedBlueOperation] = &(vision::SwapRedBlueOperation::from_json);
+  ops_ptr[vision::kToTensorOperation] = &(vision::ToTensorOperation::from_json);
   ops_ptr[vision::kUniformAugOperation] = &(vision::UniformAugOperation::from_json);
   ops_ptr[vision::kVerticalFlipOperation] = &(vision::VerticalFlipOperation::from_json);
   ops_ptr[transforms::kFillOperation] = &(transforms::FillOperation::from_json);
@@ -131,6 +131,7 @@
 #include "minddata/dataset/kernels/ir/vision/softdvpp_decode_random_crop_resize_jpeg_ir.h"
 #include "minddata/dataset/kernels/ir/vision/softdvpp_decode_resize_jpeg_ir.h"
 #include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h"
+#include "minddata/dataset/kernels/ir/vision/to_tensor_ir.h"
 #include "minddata/dataset/kernels/ir/vision/uniform_aug_ir.h"
 #include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
 #include "minddata/dataset/text/ir/kernels/text_ir.h"
@@ -27,7 +27,7 @@ namespace dataset {
 Status ToTensorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
   // Rescale and convert HWC to CHW format
-  return ToTensor(input, output, data_type_);
+  return ToTensor(input, output, output_type_);
 }
 }  // namespace dataset
 }  // namespace mindspore
@@ -28,9 +28,9 @@ namespace mindspore {
 namespace dataset {
 class ToTensorOp : public TensorOp {
  public:
-  explicit ToTensorOp(const DataType &data_type) : data_type_(data_type) {}
+  explicit ToTensorOp(const DataType &output_type) : output_type_(output_type) {}

-  explicit ToTensorOp(const std::string &data_type) { data_type_ = DataType(data_type); }
+  explicit ToTensorOp(const std::string &output_type) { output_type_ = DataType(output_type); }

   ~ToTensorOp() override = default;

@@ -39,7 +39,7 @@ class ToTensorOp : public TensorOp {
   std::string Name() const override { return kToTensorOp; }

  private:
-  DataType data_type_;
+  DataType output_type_;
 };
 }  // namespace dataset
 }  // namespace mindspore
@@ -23,9 +23,9 @@
 namespace mindspore {
 namespace dataset {
 namespace vision {
-ToTensorOperation::ToTensorOperation(const std::string &data_type) {
-  DataType temp_data_type(data_type);
-  data_type_ = temp_data_type;
+ToTensorOperation::ToTensorOperation(const std::string &output_type) {
+  DataType temp_output_type(output_type);
+  output_type_ = temp_output_type;
 }

 ToTensorOperation::~ToTensorOperation() = default;
@@ -33,26 +33,26 @@ ToTensorOperation::~ToTensorOperation() = default;
 std::string ToTensorOperation::Name() const { return kToTensorOperation; }

 Status ToTensorOperation::ValidateParams() {
-  if (data_type_ == DataType::DE_UNKNOWN) {
-    std::string err_msg = "ToTensor: Invalid data type";
+  if (output_type_ == DataType::DE_UNKNOWN) {
+    std::string err_msg = "ToTensor: Invalid data type for output_type parameter.";
     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
   return Status::OK();
 }

-std::shared_ptr<TensorOp> ToTensorOperation::Build() { return std::make_shared<ToTensorOp>(data_type_); }
+std::shared_ptr<TensorOp> ToTensorOperation::Build() { return std::make_shared<ToTensorOp>(output_type_); }

 Status ToTensorOperation::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
-  args["data_type"] = data_type_.ToString();
+  args["output_type"] = output_type_.ToString();
   *out_json = args;
   return Status::OK();
 }

 Status ToTensorOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
-  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "data_type", kToTensorOperation));
-  std::string data_type = op_params["data_type"];
-  *operation = std::make_shared<vision::ToTensorOperation>(data_type);
+  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "output_type", kToTensorOperation));
+  std::string output_type = op_params["output_type"];
+  *operation = std::make_shared<vision::ToTensorOperation>(output_type);
   return Status::OK();
 }
 }  // namespace vision
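With to_json/from_json updated, a serialized ToTensor node stores its dtype under the "output_type" key; an entry that only carries the old "data_type" key would now fail ValidateParamInJson. A hypothetical sketch of what the op entry carries (the surrounding key names are illustrative assumptions; only the op name and "output_type" come from this diff):

    # Hypothetical shape of the serialized ToTensor entry in the pipeline JSON.
    totensor_entry = {
        "tensor_op_name": "ToTensor",
        "tensor_op_params": {"output_type": "float32"},
    }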
@@ -35,7 +35,7 @@ constexpr char kToTensorOperation[] = "ToTensor";

 class ToTensorOperation : public TensorOperation {
  public:
-  explicit ToTensorOperation(const std::string &data_type);
+  explicit ToTensorOperation(const std::string &output_type);

   ~ToTensorOperation();

@@ -50,7 +50,7 @@ class ToTensorOperation : public TensorOperation {
   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);

  private:
-  DataType data_type_;
+  DataType output_type_;
 };
 }  // namespace vision
 }  // namespace dataset
@@ -3456,18 +3456,7 @@ class ToTensor(TensorOperation, PyTensorOperation):
         output_type = nptype_to_detype(output_type)
         self.output_type = str(output_type)
         self.random = False
-
-    def execute_py(self, img):
-        """
-        Execute method.
-
-        Args:
-            img (Union[PIL Image, numpy.ndarray]): PIL Image or numpy.ndarray to be type converted.
-
-        Returns:
-            numpy.ndarray, converted numpy.ndarray with desired type.
-        """
-        return util.to_tensor(img, self.output_type)
+        self.implementation = Implementation.C

     def parse(self):
         return cde.ToTensorOperation(self.output_type)
Binary files not shown (44 regenerated golden files).
@@ -95,7 +95,8 @@ def test_five_crop_error_msg():
     with pytest.raises(RuntimeError) as info:
         for _ in data:
             pass
-    error_msg = "TypeError: execute_py() takes 2 positional arguments but 6 were given"
+    error_msg = \
+        "Unexpected error. map operation: [ToTensor] failed. The op is OneToOne, can only accept one tensor as input."
     assert error_msg in str(info.value)
@@ -546,7 +546,9 @@ def test_random_crop_09():
     with pytest.raises(RuntimeError) as error_info:
         for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
             pass
-    assert "img should be PIL image." in str(error_info.value)
+    error_msg = \
+        "Unexpected error. map operation: [RandomCrop] failed. Pad: input shape is not <H,W,C> or <H, W>, got rank: 3"
+    assert error_msg in str(error_info.value)
@@ -1,4 +1,4 @@
-# Copyright 2020-2021 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ import os
 import pytest

 import numpy as np
-from util import config_get_set_num_parallel_workers, config_get_set_seed

 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
|
@ -30,14 +29,19 @@ import mindspore.dataset.transforms.transforms as transforms
|
||||||
import mindspore.dataset.vision.transforms as vision
|
import mindspore.dataset.vision.transforms as vision
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore.dataset.vision import Inter
|
from mindspore.dataset.vision import Inter
|
||||||
|
from util import config_get_set_num_parallel_workers, config_get_set_seed
|
||||||
|
|
||||||
|
|
||||||
def test_serdes_imagefolder_dataset(remove_json_files=True):
|
def test_serdes_imagefolder_dataset(remove_json_files=True):
|
||||||
"""
|
"""
|
||||||
Test simulating resnet50 dataset pipeline.
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize with dataset pipeline that simulates ResNet50
|
||||||
|
Expectation: Output verified for multiple deserialized pipelines
|
||||||
"""
|
"""
|
||||||
data_dir = "../data/dataset/testPK/data"
|
data_dir = "../data/dataset/testPK/data"
|
||||||
ds.config.set_seed(1)
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
# define data augmentation parameters
|
# define data augmentation parameters
|
||||||
rescale = 1.0 / 255.0
|
rescale = 1.0 / 255.0
|
||||||
|
@@ -98,6 +102,10 @@ def test_serdes_imagefolder_dataset(remove_json_files=True):
     logger.info("Number of data in data1: {}".format(num_samples))
     assert num_samples == 11

+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
     # Remove the generated json file
     if remove_json_files:
         delete_json_files("imagenet_dataset_pipeline")
@@ -105,10 +113,14 @@ def test_serdes_imagefolder_dataset(remove_json_files=True):

 def test_serdes_mnist_dataset(remove_json_files=True):
     """
-    Test serdes on mnist dataset pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize with MnistDataset pipeline
+    Expectation: Output verified for multiple deserialized pipelines
     """
     data_dir = "../data/dataset/testMnistData"
-    ds.config.set_seed(1)
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

     data1 = ds.MnistDataset(data_dir, num_samples=100)
     one_hot_encode = transforms.OneHot(10)  # num_classes is input argument
@@ -140,15 +152,22 @@ def test_serdes_mnist_dataset(remove_json_files=True):
     logger.info("mnist total num samples is {}".format(str(num)))
     assert num == 10

+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
     if remove_json_files:
         delete_json_files("mnist_dataset_pipeline")


 def test_serdes_cifar10_dataset(remove_json_files=True):
     """
-    Test serdes on Cifar10 dataset pipeline
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize with Cifar10Dataset pipeline
+    Expectation: Output verified for multiple deserialized pipelines
     """
     data_dir = "../data/dataset/testCifar10Data"

     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

@@ -179,19 +198,22 @@ def test_serdes_cifar10_dataset(remove_json_files=True):

     assert num_samples == 2

-    # Restore configuration num_parallel_workers
+    # Restore configuration
     ds.config.set_seed(original_seed)
     ds.config.set_num_parallel_workers(original_num_parallel_workers)

     if remove_json_files:
         delete_json_files("cifar10_dataset_pipeline")


 def test_serdes_celeba_dataset(remove_json_files=True):
     """
-    Test serdes on Celeba dataset pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize with CelebADataset pipeline
+    Expectation: Output verified for multiple deserialized pipelines
     """
-    DATA_DIR = "../data/dataset/testCelebAData/"
-    data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
+    data_dir = "../data/dataset/testCelebAData/"
+    data1 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
     # define map operations
     data1 = data1.repeat(2)
     center_crop = vision.CenterCrop((80, 80))
@@ -214,11 +236,13 @@ def test_serdes_celeba_dataset(remove_json_files=True):

 def test_serdes_csv_dataset(remove_json_files=True):
     """
-    Test serdes on Csvdataset pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize with CSVDataset pipeline
+    Expectation: Output verified for multiple deserialized pipelines
     """
-    DATA_DIR = "../data/dataset/testCSV/1.csv"
+    data_dir = "../data/dataset/testCSV/1.csv"
     data1 = ds.CSVDataset(
-        DATA_DIR,
+        data_dir,
         column_defaults=["1", "2", "3", "4"],
         column_names=['col1', 'col2', 'col3', 'col4'],
         shuffle=False)
@@ -243,9 +267,12 @@ def test_serdes_csv_dataset(remove_json_files=True):

 def test_serdes_voc_dataset(remove_json_files=True):
     """
-    Test serdes on VOC dataset pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize with VOCDataset pipeline
+    Expectation: Output verified for multiple deserialized pipelines
     """
     data_dir = "../data/dataset/testVOC2012"

     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

@@ -270,20 +297,25 @@ def test_serdes_voc_dataset(remove_json_files=True):

     assert num_samples == 7

-    # Restore configuration num_parallel_workers
+    # Restore configuration
     ds.config.set_seed(original_seed)
     ds.config.set_num_parallel_workers(original_num_parallel_workers)

     if remove_json_files:
         delete_json_files("voc_dataset_pipeline")


 def test_serdes_zip_dataset(remove_json_files=True):
     """
-    Test serdes on zip dataset pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize with zipped pipeline
+    Expectation: Output verified for multiple deserialized pipelines
     """
     files = ["../data/dataset/testTFTestAllTypes/test.data"]
     schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
-    ds.config.set_seed(1)
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

     ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
     data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
@@ -318,28 +350,35 @@ def test_serdes_zip_dataset(remove_json_files=True):
         rows += 1
     assert rows == 12

+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
     if remove_json_files:
         delete_json_files("zip_dataset_pipeline")


 def test_serdes_random_crop():
     """
-    Test serdes on RandomCrop pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipeline with RandomCrop op
+    Expectation: Output verified for multiple deserialized pipelines
     """
     logger.info("test_random_crop")
-    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    schema_dir = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)

     # First dataset
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data1 = ds.TFRecordDataset(data_dir, schema_dir, columns_list=["image"])
     decode_op = vision.Decode()
     random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
     data1 = data1.map(operations=decode_op, input_columns="image")
     data1 = data1.map(operations=random_crop_op, input_columns="image")

-    # Serializing into python dictionary
+    # Serializing into Python dictionary
     ds1_dict = ds.serialize(data1)
     # Serializing into json object
     _ = json.dumps(ds1_dict, indent=2)
@@ -348,7 +387,7 @@ def test_serdes_random_crop():
     data1_1 = ds.deserialize(input_dict=ds1_dict)

     # Second dataset
-    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data2 = ds.TFRecordDataset(data_dir, schema_dir, columns_list=["image"])
     data2 = data2.map(operations=decode_op, input_columns="image")

     for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
@ -357,14 +396,16 @@ def test_serdes_random_crop():
|
||||||
np.testing.assert_array_equal(item1['image'], item1_1['image'])
|
np.testing.assert_array_equal(item1['image'], item1_1['image'])
|
||||||
_ = item2["image"]
|
_ = item2["image"]
|
||||||
|
|
||||||
# Restore configuration num_parallel_workers
|
# Restore configuration
|
||||||
ds.config.set_seed(original_seed)
|
ds.config.set_seed(original_seed)
|
||||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
|
||||||
def test_serdes_to_device(remove_json_files=True):
|
def test_serdes_to_device(remove_json_files=True):
|
||||||
"""
|
"""
|
||||||
Test serdes on transfer dataset pipeline.
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipeline with to_device op
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
"""
|
"""
|
||||||
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
@@ -375,15 +416,20 @@ def test_serdes_to_device(remove_json_files=True):

 def test_serdes_pyvision(remove_json_files=True):
     """
-    Test serdes on py_transform pipeline.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines with Python implementation selected for vision ops
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
     """
     data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
     schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
     data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
     transforms1 = [
         vision.Decode(True),
-        vision.CenterCrop([32, 32]),
-        vision.ToTensor()
+        vision.CenterCrop([32, 32])
     ]
     transforms2 = [
         vision.RandomColorAdjust(),
|
@ -393,26 +439,257 @@ def test_serdes_pyvision(remove_json_files=True):
|
||||||
data1 = data1.map(operations=transforms.Compose(transforms1), input_columns=["image"])
|
data1 = data1.map(operations=transforms.Compose(transforms1), input_columns=["image"])
|
||||||
data1 = data1.map(operations=transforms.RandomApply(transforms2), input_columns=["image"])
|
data1 = data1.map(operations=transforms.RandomApply(transforms2), input_columns=["image"])
|
||||||
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if remove_json_files:
|
||||||
|
delete_json_files("pyvision_dataset_pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_pyfunc(remove_json_files=True):
|
||||||
|
"""
|
||||||
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipelines with Python functions
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
|
"""
|
||||||
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
data2 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
data2 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||||
data2 = data2.map(operations=(lambda x, y, z: (
|
data2 = data2.map(operations=(lambda x, y, z: (
|
||||||
np.array(x).flatten().reshape(10, 39),
|
np.array(x).flatten().reshape(10, 39),
|
||||||
np.array(y).flatten().reshape(10, 39),
|
np.array(y).flatten().reshape(10, 39),
|
||||||
np.array(z).flatten().reshape(10, 1)
|
np.array(z).flatten().reshape(10, 1)
|
||||||
)))
|
)))
|
||||||
ds.serialize(data2, "pyvision_dataset_pipeline.json")
|
ds.serialize(data2, "pyfunc_dataset_pipeline.json")
|
||||||
assert validate_jsonfile("pyvision_dataset_pipeline.json") is True
|
assert validate_jsonfile("pyfunc_dataset_pipeline.json") is True
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
if remove_json_files:
|
if remove_json_files:
|
||||||
delete_json_files("pyvision_dataset_pipeline")
|
delete_json_files("pyfunc_dataset_pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_inter_mixed_map(remove_json_files=True):
|
||||||
|
"""
|
||||||
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipelines in which each map op has the same
|
||||||
|
implementation (Python or C++) of ops
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
|
"""
|
||||||
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||||
|
# The following map op uses Python implementation of ops
|
||||||
|
data1 = data1.map(operations=[vision.Decode(True), vision.CenterCrop([24, 24])], input_columns=["image"])
|
||||||
|
# The following map op uses C++ implementation of ToTensor op
|
||||||
|
data1 = data1.map(operations=[vision.ToTensor()], input_columns=["image"])
|
||||||
|
# The following map op uses C++ implementation of ops
|
||||||
|
data1 = data1.map(operations=[vision.HorizontalFlip(), vision.VerticalFlip()], input_columns=["image"])
|
||||||
|
# The following map op uses Python implementation of ops
|
||||||
|
data1 = data1.map(operations=[vision.ToPIL(), vision.FiveCrop((18, 22))], input_columns=["image"])
|
||||||
|
|
||||||
|
util_check_serialize_deserialize_file(data1, "inter_mixed_map_pipeline", remove_json_files)
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if remove_json_files:
|
||||||
|
delete_json_files("inter_mixed_map_pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_intra_mixed_py2c_map(remove_json_files=True):
|
||||||
|
"""
|
||||||
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipelines in which each map op has a mix of Python implementation
|
||||||
|
then C++ implementation of ops
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
|
"""
|
||||||
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||||
|
# The following map op uses mixed implementation of ops:
|
||||||
|
# - Decode - Python implementation
|
||||||
|
# - CenterCrop - Python Implementation
|
||||||
|
# - ToTensor - C++ implementation
|
||||||
|
# - RandonHorizontalFlip - C++ implementation
|
||||||
|
# - VerticalFlip - C++ implementation
|
||||||
|
transforms_list = [vision.Decode(True),
|
||||||
|
vision.CenterCrop([24, 24]),
|
||||||
|
vision.ToTensor(),
|
||||||
|
vision.RandomHorizontalFlip(),
|
||||||
|
vision.VerticalFlip()]
|
||||||
|
data1 = data1.map(operations=transforms_list, input_columns=["image"])
|
||||||
|
data2 = util_check_serialize_deserialize_file(data1, "intra_mixed_py2c_map_pipeline", False)
|
||||||
|
|
||||||
|
num_itr = 0
|
||||||
|
# Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
|
||||||
|
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||||
|
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||||
|
np.testing.assert_array_equal(item1['image'], item2['image'])
|
||||||
|
num_itr += 1
|
||||||
|
assert num_itr == 3
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if remove_json_files:
|
||||||
|
delete_json_files("intra_mixed_py2c_map_pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_intra_mixed_c2py_map(remove_json_files=True):
|
||||||
|
"""
|
||||||
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipelines in which each map op has a mix of C++ implementation
|
||||||
|
then Python implementation of ops
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
|
"""
|
||||||
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||||
|
# The following map op uses mixed implementation of ops:
|
||||||
|
# - Decode - C++ implementation
|
||||||
|
# - RandomSolarize - C++ implementation
|
||||||
|
# - ToPIL - Python Implementation
|
||||||
|
# - CenterCrop - Python Implementation
|
||||||
|
transforms_list = [vision.Decode(),
|
||||||
|
vision.RandomSolarize((0, 127)),
|
||||||
|
vision.ToPIL(),
|
||||||
|
vision.CenterCrop([64, 64])]
|
||||||
|
data1 = data1.map(operations=transforms_list, input_columns=["image"])
|
||||||
|
data2 = util_check_serialize_deserialize_file(data1, "intra_mixed_c2py_map_pipeline", False)
|
||||||
|
|
||||||
|
num_itr = 0
|
||||||
|
# Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
|
||||||
|
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||||
|
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||||
|
np.testing.assert_array_equal(item1['image'], item2['image'])
|
||||||
|
num_itr += 1
|
||||||
|
assert num_itr == 3
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if remove_json_files:
|
||||||
|
delete_json_files("intra_mixed_c2py_map_pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_totensor_normalize(remove_json_files=True):
|
||||||
|
"""
|
||||||
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipelines in which each map op has common scenario with
|
||||||
|
ToTensor and Normalize ops
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
|
"""
|
||||||
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||||
|
# The following map op uses mixed implementation of ops:
|
||||||
|
# - Decode - Python implementation
|
||||||
|
# - CenterCrop - Python Implementation
|
||||||
|
# - ToTensor - C++ implementation
|
||||||
|
# - Normalize - C++ implementation
|
||||||
|
transforms_list = [vision.Decode(True),
|
||||||
|
vision.CenterCrop([30, 50]),
|
||||||
|
vision.ToTensor(),
|
||||||
|
vision.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], is_hwc=False)]
|
||||||
|
data1 = data1.map(operations=transforms_list, input_columns=["image"])
|
||||||
|
data2 = util_check_serialize_deserialize_file(data1, "totensor_normalize_pipeline", False)
|
||||||
|
|
||||||
|
num_itr = 0
|
||||||
|
# Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
|
||||||
|
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||||
|
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||||
|
np.testing.assert_array_equal(item1['image'], item2['image'])
|
||||||
|
num_itr += 1
|
||||||
|
assert num_itr == 3
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if remove_json_files:
|
||||||
|
delete_json_files("totensor_normalize_pipeline")
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_tonumpy(remove_json_files=True):
|
||||||
|
"""
|
||||||
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipelines with ToNumpy op
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
|
"""
|
||||||
|
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
|
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
|
||||||
|
# The following map op uses mixed implementation of ops:
|
||||||
|
# - Decode - Python implementation
|
||||||
|
# - CenterCrop - Python Implementation
|
||||||
|
# - ToNumpy - C++ implementation set
|
||||||
|
# - Crop - C++ implementation
|
||||||
|
transforms_list = [vision.Decode(to_pil=True),
|
||||||
|
vision.CenterCrop((200, 300)),
|
||||||
|
vision.ToNumpy(),
|
||||||
|
vision.Crop([5, 5], [40, 60])]
|
||||||
|
data1 = data1.map(operations=transforms_list, input_columns=["image"])
|
||||||
|
data2 = util_check_serialize_deserialize_file(data1, "tonumpy_pipeline", False)
|
||||||
|
|
||||||
|
num_itr = 0
|
||||||
|
# Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
|
||||||
|
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||||
|
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||||
|
np.testing.assert_array_equal(item1['image'], item2['image'])
|
||||||
|
num_itr += 1
|
||||||
|
assert num_itr == 3
|
||||||
|
|
||||||
|
# Restore configuration
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
if remove_json_files:
|
||||||
|
delete_json_files("tonumpy_pipeline")
|
||||||
|
|
||||||
|
|
||||||
def test_serdes_uniform_augment(remove_json_files=True):
|
def test_serdes_uniform_augment(remove_json_files=True):
|
||||||
"""
|
"""
|
||||||
Test serdes on uniform augment.
|
Feature: Serialize and Deserialize Support
|
||||||
|
Description: Test serialize and deserialize on pipeline with UniformAugment op
|
||||||
|
Expectation: Serialized versus Deserialized+reserialized pipeline output verified
|
||||||
"""
|
"""
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
|
||||||
data_dir = "../data/dataset/testPK/data"
|
data_dir = "../data/dataset/testPK/data"
|
||||||
data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||||
ds.config.set_seed(1)
|
|
||||||
|
|
||||||
transforms_ua = [vision.RandomHorizontalFlip(),
|
transforms_ua = [vision.RandomHorizontalFlip(),
|
||||||
vision.RandomVerticalFlip(),
|
vision.RandomVerticalFlip(),
|
||||||
|
@@ -426,11 +703,18 @@ def test_serdes_uniform_augment(remove_json_files=True):
     data = data.map(operations=transforms_all, input_columns="image", num_parallel_workers=1)
     util_check_serialize_deserialize_file(data, "uniform_augment_pipeline", remove_json_files)

+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+
 def skip_test_serdes_fill(remove_json_files=True):
     """
-    Test serdes on Fill data transform.
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines with Fill op
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
     """

     def gen():
         yield (np.array([4, 5, 6, 7], dtype=np.int32),)

@@ -447,7 +731,9 @@ def skip_test_serdes_fill(remove_json_files=True):

 def test_serdes_exception():
     """
-    Test exception case in serdes
+    Feature: Serialize and Deserialize Support
+    Description: Test exception cases
+    Expectation: Correct error is verified
     """
     data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
     schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@@ -505,6 +791,7 @@ def delete_json_files(filename):
         except IOError:
             logger.info("Error while deleting: {}".format(f))


 if __name__ == '__main__':
     test_serdes_imagefolder_dataset()
     test_serdes_mnist_dataset()
@@ -516,6 +803,12 @@ if __name__ == '__main__':
     test_serdes_random_crop()
     test_serdes_to_device()
     test_serdes_pyvision()
+    test_serdes_pyfunc()
+    test_serdes_inter_mixed_map()
+    test_serdes_intra_mixed_py2c_map()
+    test_serdes_intra_mixed_c2py_map()
+    test_serdes_totensor_normalize()
+    test_serdes_tonumpy()
     test_serdes_uniform_augment()
     skip_test_serdes_fill()
     test_serdes_exception()
@@ -179,7 +179,8 @@ def test_ten_crop_wrong_img_error_msg():

     with pytest.raises(RuntimeError) as info:
         data.create_tuple_iterator(num_epochs=1).__next__()
-    error_msg = "TypeError: execute_py() takes 2 positional arguments but 11 were given"
+    error_msg = \
+        "Unexpected error. map operation: [ToTensor] failed. The op is OneToOne, can only accept one tensor as input."
     assert error_msg in str(info.value)
@@ -0,0 +1,254 @@
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing dataset serialize and deserialize in DE
+"""
+import filecmp
+import glob
+import json
+import os
+
+import numpy as np
+
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.py_transforms as py_transforms
+import mindspore.dataset.vision.c_transforms as c_vision
+import mindspore.dataset.vision.py_transforms as py_vision
+from mindspore import log as logger
+from ..dataset.util import config_get_set_num_parallel_workers, config_get_set_seed
+
+
+def test_serdes_pyvision(remove_json_files=True):
+    """
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines with Python vision ops
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
+    """
+    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
+    transforms1 = [
+        py_vision.Decode(),
+        py_vision.CenterCrop([32, 32])
+    ]
+    transforms2 = [
+        py_vision.RandomColorAdjust(),
+        py_vision.FiveCrop(1),
+        py_vision.Grayscale()
+    ]
+    data1 = data1.map(operations=py_transforms.Compose(transforms1), input_columns=["image"])
+    data1 = data1.map(operations=py_transforms.RandomApply(transforms2), input_columns=["image"])
+    util_check_serialize_deserialize_file(data1, "depr_pyvision_dataset_pipeline", remove_json_files)
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+    if remove_json_files:
+        delete_json_files("depr_pyvision_dataset_pipeline")
+
+
+def test_serdes_pyfunc(remove_json_files=True):
+    """
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines with Python functions
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
+    """
+    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    data2 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
+    data2 = data2.map(operations=(lambda x, y, z: (
+        np.array(x).flatten().reshape(10, 39),
+        np.array(y).flatten().reshape(10, 39),
+        np.array(z).flatten().reshape(10, 1)
+    )))
+    ds.serialize(data2, "pyfunc_dataset_pipeline.json")
+    assert validate_jsonfile("pyfunc_dataset_pipeline.json") is True
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+    if remove_json_files:
+        delete_json_files("depr_pyfunc_dataset_pipeline")
+
+
+def test_serdes_inter_mixed_map(remove_json_files=True):
+    """
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines in which each map op has Python ops or C++ ops
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
+    """
+    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
+    # The following map op uses Python ops
+    data1 = data1.map(operations=[py_vision.Decode(), py_vision.CenterCrop([24, 24])], input_columns=["image"])
+    # The following map op uses Python ops
+    data1 = data1.map(operations=[py_vision.ToTensor(), py_vision.ToPIL()], input_columns=["image"])
+    # The following map op uses C++ ops
+    data1 = data1.map(operations=[c_vision.HorizontalFlip(), c_vision.VerticalFlip()], input_columns=["image"])
+    # The following map op uses Python ops
+    data1 = data1.map(operations=[py_vision.ToPIL(), py_vision.FiveCrop((18, 22))], input_columns=["image"])
+
+    util_check_serialize_deserialize_file(data1, "depr_inter_mixed_map_pipeline", remove_json_files)
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+    if remove_json_files:
+        delete_json_files("depr_inter_mixed_map_pipeline")
+
+
+def test_serdes_intra_mixed_py2c_map(remove_json_files=True):
+    """
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines in which each map op has a mix of Python ops
+        then C++ ops
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
+    """
+    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
+    transforms_list = [py_vision.Decode(),
+                       py_vision.CenterCrop([24, 24]),
+                       py_vision.ToTensor(),
+                       py_vision.Normalize([0.48, 0.45, 0.40], [0.22, 0.22, 0.22]),
+                       c_vision.RandomHorizontalFlip(),
+                       c_vision.VerticalFlip()]
+    data1 = data1.map(operations=transforms_list, input_columns=["image"])
+    data2 = util_check_serialize_deserialize_file(data1, "depr_intra_mixed_py2c_map_pipeline", False)
+
+    num_itr = 0
+    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
+    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
+                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
+        np.testing.assert_array_equal(item1['image'], item2['image'])
+        num_itr += 1
+    assert num_itr == 3
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+    if remove_json_files:
+        delete_json_files("depr_intra_mixed_py2c_map_pipeline")
+
+
+def test_serdes_intra_mixed_c2py_map(remove_json_files=True):
+    """
+    Feature: Serialize and Deserialize Support
+    Description: Test serialize and deserialize on pipelines in which each map op has a mix of C++ ops
+        then Python ops
+    Expectation: Serialized versus Deserialized+reserialized pipeline output verified
+    """
+    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
+    transforms_list = [c_vision.Decode(),
+                       c_vision.RandomSolarize((0, 127)),
+                       py_vision.ToPIL(),
+                       py_vision.CenterCrop([64, 64])]
+    data1 = data1.map(operations=transforms_list, input_columns=["image"])
+    data2 = util_check_serialize_deserialize_file(data1, "depr_intra_mixed_c2py_map_pipeline", False)
+
+    num_itr = 0
+    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
+    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
+                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
+        np.testing.assert_array_equal(item1['image'], item2['image'])
+        num_itr += 1
+    assert num_itr == 3
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+
+    if remove_json_files:
+        delete_json_files("depr_intra_mixed_c2py_map_pipeline")
+
+
+def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
+    """
+    Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
+    after serializing and if it remains the same after repeatedly saving and loading.
+    :param data_orig: original data pipeline to be serialized
+    :param filename: filename to be saved as json format
+    :param remove_json_files: whether to remove the json file after testing
+    :return: The data pipeline after serializing and deserializing using the original pipeline
+    """
+    file1 = filename + ".json"
+    file2 = filename + "_1.json"
+    ds.serialize(data_orig, file1)
+    assert validate_jsonfile(file1) is True
+    assert validate_jsonfile("wrong_name.json") is False
+
+    data_changed = ds.deserialize(json_filepath=file1)
+    ds.serialize(data_changed, file2)
+    assert validate_jsonfile(file2) is True
+    assert filecmp.cmp(file1, file2, shallow=False)
+
+    # Remove the generated json file
+    if remove_json_files:
+        delete_json_files(filename)
+    return data_changed
+
+
+def validate_jsonfile(filepath):
+    try:
+        file_exist = os.path.exists(filepath)
+        with open(filepath, 'r') as jfile:
+            loaded_json = json.load(jfile)
+    except IOError:
+        return False
+    return file_exist and isinstance(loaded_json, dict)
+
+
+def delete_json_files(filename):
+    file_list = glob.glob(filename + '.json') + glob.glob(filename + '_1.json')
+    for f in file_list:
+        try:
+            os.remove(f)
+        except IOError:
+            logger.info("Error while deleting: {}".format(f))
+
+
+if __name__ == '__main__':
+    test_serdes_pyvision()
+    test_serdes_pyfunc()
+    test_serdes_inter_mixed_map()
+    test_serdes_intra_mixed_py2c_map()
+    test_serdes_intra_mixed_c2py_map()