forked from mindspore-Ecosystem/mindspore
parent
56bd92b88f
commit
738ae2c78d
|
@ -422,7 +422,7 @@ PYBIND_REGISTER(
|
|||
PYBIND_REGISTER(RandomSolarizeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomSolarizeOp, TensorOp, std::shared_ptr<RandomSolarizeOp>>(*m,
|
||||
"RandomSolarizeOp")
|
||||
.def(py::init<uint8_t, uint8_t>());
|
||||
.def(py::init<std::vector<uint8_t>>());
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -241,8 +241,8 @@ std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degre
|
|||
}
|
||||
|
||||
// Function to create RandomSolarizeOperation.
|
||||
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, uint8_t threshold_max) {
|
||||
auto op = std::make_shared<RandomSolarizeOperation>(threshold_min, threshold_max);
|
||||
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold) {
|
||||
auto op = std::make_shared<RandomSolarizeOperation>(threshold);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -811,19 +811,22 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomSolarizeOperation.
|
||||
RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
|
||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
|
||||
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
|
||||
|
||||
bool RandomSolarizeOperation::ValidateParams() {
|
||||
if (threshold_max_ < threshold_min_) {
|
||||
MS_LOG(ERROR) << "RandomSolarize: threshold_max must be greater or equal to threshold_min";
|
||||
if (threshold_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomSolarize: threshold vector has incorrect size: " << threshold_.size();
|
||||
return false;
|
||||
}
|
||||
if (threshold_.at(0) > threshold_.at(1)) {
|
||||
MS_LOG(ERROR) << "RandomSolarize: threshold must be passed in a min, max format";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
|
||||
std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_min_, threshold_max_);
|
||||
std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
|
|
|
@ -249,10 +249,9 @@ std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> deg
|
|||
|
||||
/// \brief Function to create a RandomSolarize TensorOperation.
|
||||
/// \notes Invert pixels within specified range. If min=max, then it inverts all pixel above that threshold
|
||||
/// \param[in] threshold_min - lower limit
|
||||
/// \param[in] threshold_max - upper limit
|
||||
/// \param[in] threshold - a vector with two elements specifying the pixel range to invert.
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min = 0, uint8_t threshold_max = 255);
|
||||
std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold = {0, 255});
|
||||
|
||||
/// \brief Function to create a RandomVerticalFlip TensorOperation.
|
||||
/// \notes Tensor operation to perform random vertical flip.
|
||||
|
@ -657,7 +656,7 @@ class SwapRedBlueOperation : public TensorOperation {
|
|||
|
||||
class RandomSolarizeOperation : public TensorOperation {
|
||||
public:
|
||||
explicit RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max);
|
||||
explicit RandomSolarizeOperation(std::vector<uint8_t> threshold);
|
||||
|
||||
~RandomSolarizeOperation() = default;
|
||||
|
||||
|
@ -666,8 +665,7 @@ class RandomSolarizeOperation : public TensorOperation {
|
|||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
uint8_t threshold_min_;
|
||||
uint8_t threshold_max_;
|
||||
std::vector<uint8_t> threshold_;
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace api
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/image/random_solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
@ -24,6 +26,9 @@ namespace dataset {
|
|||
|
||||
Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1];
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_,
|
||||
"threshold_min must be smaller or equal to threshold_max.");
|
||||
|
||||
|
@ -35,7 +40,8 @@ Status RandomSolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
|||
threshold_min = threshold_max;
|
||||
threshold_max = temp;
|
||||
}
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold_min, threshold_max));
|
||||
std::vector<uint8_t> inputs = {threshold_min, threshold_max};
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(inputs));
|
||||
return op->Compute(input, output);
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
@ -31,10 +32,7 @@ namespace dataset {
|
|||
class RandomSolarizeOp : public SolarizeOp {
|
||||
public:
|
||||
// Pick a random threshold value to solarize the image with
|
||||
explicit RandomSolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255)
|
||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); }
|
||||
|
||||
~RandomSolarizeOp() = default;
|
||||
|
||||
|
@ -43,8 +41,7 @@ class RandomSolarizeOp : public SolarizeOp {
|
|||
std::string Name() const override { return kRandomSolarizeOp; }
|
||||
|
||||
private:
|
||||
uint8_t threshold_min_;
|
||||
uint8_t threshold_max_;
|
||||
std::vector<uint8_t> threshold_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -27,6 +27,8 @@ const uint8_t kPixelValue = 255;
|
|||
Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
uint8_t threshold_min_ = threshold_[0], threshold_max_ = threshold_[1];
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(threshold_min_ <= threshold_max_,
|
||||
"threshold_min must be smaller or equal to threshold_max.");
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
@ -28,8 +29,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class SolarizeOp : public TensorOp {
|
||||
public:
|
||||
explicit SolarizeOp(uint8_t threshold_min = 0, uint8_t threshold_max = 255)
|
||||
: threshold_min_(threshold_min), threshold_max_(threshold_max) {}
|
||||
explicit SolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) {}
|
||||
|
||||
~SolarizeOp() = default;
|
||||
|
||||
|
@ -38,8 +38,7 @@ class SolarizeOp : public TensorOp {
|
|||
std::string Name() const override { return kSolarizeOp; }
|
||||
|
||||
private:
|
||||
uint8_t threshold_min_;
|
||||
uint8_t threshold_max_;
|
||||
std::vector<uint8_t> threshold_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1045,4 +1045,4 @@ class RandomSolarize(cde.RandomSolarizeOp):
|
|||
|
||||
@check_random_solarize
|
||||
def __init__(self, threshold=(0, 255)):
|
||||
super().__init__(*threshold)
|
||||
super().__init__(threshold)
|
||||
|
|
|
@ -1109,31 +1109,23 @@ TEST_F(MindDataTestPipeline, TestUniformAugWithOps) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomSolarize) {
|
||||
TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> random_solarize =
|
||||
mindspore::dataset::api::vision::RandomSolarize(23, 23); // vision::RandomSolarize();
|
||||
std::vector<uint8_t> threshold = {10, 100};
|
||||
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
|
||||
EXPECT_NE(random_solarize, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({random_solarize});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 1;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
|
@ -1151,8 +1143,61 @@ TEST_F(MindDataTestPipeline, TestRandomSolarize) {
|
|||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with default params.";
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize();
|
||||
EXPECT_NE(random_solarize, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({random_solarize});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomSolarizeFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarize with invalid params.";
|
||||
std::vector<uint8_t> threshold = {13, 1};
|
||||
std::shared_ptr<TensorOperation> random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
|
||||
EXPECT_EQ(random_solarize, nullptr);
|
||||
|
||||
threshold = {1, 2, 3};
|
||||
random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
|
||||
EXPECT_EQ(random_solarize, nullptr);
|
||||
|
||||
threshold = {1};
|
||||
random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold);
|
||||
EXPECT_EQ(random_solarize, nullptr);
|
||||
}
|
||||
|
|
|
@ -38,7 +38,9 @@ TEST_F(MindDataTestRandomSolarizeOp, TestOp1) {
|
|||
uint32_t curr_seed = GlobalContext::config_manager()->seed();
|
||||
GlobalContext::config_manager()->set_seed(0);
|
||||
|
||||
std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(100, 100));
|
||||
std::vector<uint8_t> threshold = {100, 100};
|
||||
std::unique_ptr<RandomSolarizeOp> op(new RandomSolarizeOp(threshold));
|
||||
|
||||
EXPECT_TRUE(op->OneToOne());
|
||||
Status s = op->Compute(input_tensor_, &output_tensor_);
|
||||
|
||||
|
|
|
@ -74,7 +74,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) {
|
|||
MS_LOG(INFO) << "Doing testSolarizeOp3 - Pass in only threshold_min parameter";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1));
|
||||
std::vector<uint8_t> threshold ={1, 255};
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 0};
|
||||
|
@ -98,8 +99,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp3) {
|
|||
TEST_F(MindDataTestSolarizeOp, TestOp4) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp4 - Pass in both threshold parameters.";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230));
|
||||
std::vector<uint8_t> threshold ={1, 230};
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255};
|
||||
|
@ -123,8 +124,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp4) {
|
|||
TEST_F(MindDataTestSolarizeOp, TestOp5) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp5 - Rank 2 input tensor.";
|
||||
|
||||
// unsigned int threshold = 128;
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(1, 230));
|
||||
std::vector<uint8_t> threshold ={1, 230};
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::vector<uint8_t> expected_output_vector = {252, 251, 196, 45, 255};
|
||||
|
@ -149,7 +150,8 @@ TEST_F(MindDataTestSolarizeOp, TestOp5) {
|
|||
TEST_F(MindDataTestSolarizeOp, TestOp6) {
|
||||
MS_LOG(INFO) << "Doing testSolarizeOp6 - Bad Input.";
|
||||
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(10, 1));
|
||||
std::vector<uint8_t> threshold ={10, 1};
|
||||
std::unique_ptr<SolarizeOp> op(new SolarizeOp(threshold));
|
||||
|
||||
std::vector<uint8_t> test_vector = {3, 4, 59, 210, 255};
|
||||
std::shared_ptr<Tensor> test_input_tensor;
|
||||
|
|
Binary file not shown.
|
@ -17,68 +17,89 @@ Testing RandomSolarizeOp op in DE
|
|||
"""
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
|
||||
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers, \
|
||||
visualize_one_channel_dataset
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def test_random_solarize_op(threshold=None, plot=False):
|
||||
def test_random_solarize_op(threshold=(10, 150), plot=False, run_golden=True):
|
||||
"""
|
||||
Test RandomSolarize
|
||||
"""
|
||||
logger.info("Test RandomSolarize")
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
|
||||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
if threshold is None:
|
||||
solarize_op = vision.RandomSolarize()
|
||||
else:
|
||||
solarize_op = vision.RandomSolarize(threshold)
|
||||
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=solarize_op)
|
||||
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data2 = data2.map(input_columns=["image"], operations=decode_op)
|
||||
|
||||
if run_golden:
|
||||
filename = "random_solarize_01_result.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
image_solarized = []
|
||||
image = []
|
||||
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
image_solarized.append(item1["image"].copy())
|
||||
image.append(item2["image"].copy())
|
||||
if plot:
|
||||
visualize_list(image, image_solarized)
|
||||
|
||||
|
||||
def test_random_solarize_md5():
|
||||
"""
|
||||
Test RandomSolarize
|
||||
"""
|
||||
logger.info("Test RandomSolarize")
|
||||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
random_solarize_op = vision.RandomSolarize((10, 150))
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=random_solarize_op)
|
||||
# Compare with expected md5 from images
|
||||
filename = "random_solarize_01_result.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
# Restore config setting
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_random_solarize_mnist(plot=False, run_golden=True):
|
||||
"""
|
||||
Test RandomSolarize op with MNIST dataset (Grayscale images)
|
||||
"""
|
||||
|
||||
mnist_1 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||
mnist_2 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
|
||||
mnist_2 = mnist_2.map(input_columns="image", operations=vision.RandomSolarize((0, 255)))
|
||||
|
||||
images = []
|
||||
images_trans = []
|
||||
labels = []
|
||||
|
||||
for _, (data_orig, data_trans) in enumerate(zip(mnist_1, mnist_2)):
|
||||
image_orig, label_orig = data_orig
|
||||
image_trans, _ = data_trans
|
||||
images.append(image_orig)
|
||||
labels.append(label_orig)
|
||||
images_trans.append(image_trans)
|
||||
|
||||
if plot:
|
||||
visualize_one_channel_dataset(images, images_trans, labels)
|
||||
|
||||
if run_golden:
|
||||
filename = "random_solarize_02_result.npz"
|
||||
save_and_check_md5(mnist_2, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_random_solarize_errors():
|
||||
"""
|
||||
Test that RandomSolarize errors with bad input
|
||||
|
@ -105,8 +126,8 @@ def test_random_solarize_errors():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_solarize_op((100, 100), plot=True)
|
||||
test_random_solarize_op((12, 120), plot=True)
|
||||
test_random_solarize_op(plot=True)
|
||||
test_random_solarize_op((10, 150), plot=True, run_golden=True)
|
||||
test_random_solarize_op((12, 120), plot=True, run_golden=False)
|
||||
test_random_solarize_op(plot=True, run_golden=False)
|
||||
test_random_solarize_mnist(plot=True, run_golden=True)
|
||||
test_random_solarize_errors()
|
||||
test_random_solarize_md5()
|
||||
|
|
Loading…
Reference in New Issue