forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I40GZE] add new data operator RandomAdjustSharpness
This commit is contained in:
parent
13a48747a8
commit
d072a4e4bf
|
@ -36,6 +36,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/normalize_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/normalize_pad_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/pad_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_adjust_sharpness_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_affine_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_auto_contrast_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_color_adjust_ir.h"
|
||||
|
@ -275,6 +276,18 @@ PYBIND_REGISTER(PadOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomAdjustSharpnessOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::RandomAdjustSharpnessOperation, TensorOperation,
|
||||
std::shared_ptr<vision::RandomAdjustSharpnessOperation>>(
|
||||
*m, "RandomAdjustSharpnessOperation")
|
||||
.def(py::init([](float degree, float prob) {
|
||||
auto random_adjust_sharpness =
|
||||
std::make_shared<vision::RandomAdjustSharpnessOperation>(degree, prob);
|
||||
THROW_IF_ERROR(random_adjust_sharpness->ValidateParams());
|
||||
return random_adjust_sharpness;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomAffineOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::RandomAffineOperation, TensorOperation, std::shared_ptr<vision::RandomAffineOperation>>(
|
||||
|
|
|
@ -40,6 +40,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/normalize_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/normalize_pad_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/pad_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_adjust_sharpness_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_affine_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_auto_contrast_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/random_color_adjust_ir.h"
|
||||
|
@ -450,6 +451,19 @@ Pad::Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderTy
|
|||
std::shared_ptr<TensorOperation> Pad::Parse() {
|
||||
return std::make_shared<PadOperation>(data_->padding_, data_->fill_value_, data_->padding_mode_);
|
||||
}
|
||||
|
||||
// RandomAdjustSharpness Transform Operation.
|
||||
struct RandomAdjustSharpness::Data {
|
||||
Data(float degree, float prob) : degree_(degree), probability_(prob) {}
|
||||
float degree_;
|
||||
float probability_;
|
||||
};
|
||||
|
||||
RandomAdjustSharpness::RandomAdjustSharpness(float degree, float prob) : data_(std::make_shared<Data>(degree, prob)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> RandomAdjustSharpness::Parse() {
|
||||
return std::make_shared<RandomAdjustSharpnessOperation>(data_->degree_, data_->probability_);
|
||||
}
|
||||
#endif // not ENABLE_ANDROID
|
||||
|
||||
// RandomAffine Transform Operation.
|
||||
|
|
|
@ -348,6 +348,28 @@ class RandomAutoContrast final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Randomly adjust the sharpness of the input image with a given probability.
|
||||
class RandomAdjustSharpness final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] degree A float representing sharpness adjustment degree, which must be non negative.
|
||||
/// \param[in] prob A float representing the probability of the image being sharpness adjusted, which
|
||||
/// must in range of [0, 1] (default=0.5).
|
||||
explicit RandomAdjustSharpness(float degree, float prob = 0.5);
|
||||
|
||||
/// \brief Destructor.
|
||||
~RandomAdjustSharpness() = default;
|
||||
|
||||
protected:
|
||||
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Blend an image with its grayscale version with random weights
|
||||
/// t and 1 - t generated from a given range. If the range is trivial
|
||||
/// then the weights are determinate and t equals to the bound of the interval.
|
||||
|
|
|
@ -28,6 +28,7 @@ add_library(kernels-image OBJECT
|
|||
normalize_pad_op.cc
|
||||
pad_op.cc
|
||||
posterize_op.cc
|
||||
random_adjust_sharpness_op.cc
|
||||
random_affine_op.cc
|
||||
random_auto_contrast_op.cc
|
||||
random_color_adjust_op.cc
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/random_adjust_sharpness_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const float RandomAdjustSharpnessOp::kDefProbability = 0.5;
|
||||
|
||||
Status RandomAdjustSharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
if (distribution_(rnd_)) {
|
||||
return SharpnessOp::Compute(input, output);
|
||||
}
|
||||
*output = input;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_ADJUST_SHARPNESS_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_ADJUST_SHARPNESS_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/kernels/image/sharpness_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RandomAdjustSharpnessOp : public SharpnessOp {
|
||||
public:
|
||||
// Default values, also used by python_bindings.cc
|
||||
static const float kDefProbability;
|
||||
|
||||
explicit RandomAdjustSharpnessOp(float degree, float prob = kDefProbability)
|
||||
: SharpnessOp(degree), distribution_(prob) {
|
||||
is_deterministic_ = false;
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
~RandomAdjustSharpnessOp() override = default;
|
||||
|
||||
// Provide stream operator for displaying it
|
||||
friend std::ostream &operator<<(std::ostream &out, const RandomAdjustSharpnessOp &so) {
|
||||
so.Print(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kRandomAdjustSharpnessOp; }
|
||||
|
||||
private:
|
||||
std::mt19937 rnd_;
|
||||
std::bernoulli_distribution distribution_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_ADJUST_SHARPNESS_OP_H_
|
|
@ -21,6 +21,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
|||
normalize_ir.cc
|
||||
normalize_pad_ir.cc
|
||||
pad_ir.cc
|
||||
random_adjust_sharpness_ir.cc
|
||||
random_affine_ir.cc
|
||||
random_auto_contrast_ir.cc
|
||||
random_color_adjust_ir.cc
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/ir/vision/random_adjust_sharpness_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/random_adjust_sharpness_op.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
#ifndef ENABLE_ANDROID
|
||||
// RandomAdjustSharpnessOperation
|
||||
RandomAdjustSharpnessOperation::RandomAdjustSharpnessOperation(float degree, float prob)
|
||||
: degree_(degree), probability_(prob) {}
|
||||
|
||||
RandomAdjustSharpnessOperation::~RandomAdjustSharpnessOperation() = default;
|
||||
|
||||
std::string RandomAdjustSharpnessOperation::Name() const { return kRandomAdjustSharpnessOperation; }
|
||||
|
||||
Status RandomAdjustSharpnessOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("RandomAdjustSharpness", "degree", degree_));
|
||||
RETURN_IF_NOT_OK(ValidateProbability("RandomAdjustSharpness", probability_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomAdjustSharpnessOperation::Build() {
|
||||
std::shared_ptr<RandomAdjustSharpnessOp> tensor_op = std::make_shared<RandomAdjustSharpnessOp>(degree_, probability_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status RandomAdjustSharpnessOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["degree"] = degree_;
|
||||
args["prob"] = probability_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RandomAdjustSharpnessOperation::from_json(nlohmann::json op_params,
|
||||
std::shared_ptr<TensorOperation> *operation) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("degree") != op_params.end(), "Failed to find degree");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("prob") != op_params.end(), "Failed to find prob");
|
||||
float degree = op_params["degree"];
|
||||
float prob = op_params["prob"];
|
||||
*operation = std::make_shared<vision::RandomAdjustSharpnessOperation>(degree, prob);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_ADJUST_SHARPNESS_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_ADJUST_SHARPNESS_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace vision {
|
||||
|
||||
constexpr char kRandomAdjustSharpnessOperation[] = "RandomAdjustSharpness";
|
||||
|
||||
class RandomAdjustSharpnessOperation : public TensorOperation {
|
||||
public:
|
||||
RandomAdjustSharpnessOperation(float degree, float prob);
|
||||
|
||||
~RandomAdjustSharpnessOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
float degree_;
|
||||
float probability_;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_ADJUST_SHARPNESS_IR_H_
|
|
@ -79,6 +79,7 @@ constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
|
|||
constexpr char kNormalizeOp[] = "NormalizeOp";
|
||||
constexpr char kNormalizePadOp[] = "NormalizePadOp";
|
||||
constexpr char kPadOp[] = "PadOp";
|
||||
constexpr char kRandomAdjustSharpnessOp[] = "RandomAdjustSharpnessOp";
|
||||
constexpr char kRandomAffineOp[] = "RandomAffineOp";
|
||||
constexpr char kRandomAutoContrastOp[] = "RandomAutoContrastOp";
|
||||
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
|
||||
|
|
|
@ -52,6 +52,7 @@ from .validators import check_prob, check_crop, check_center_crop, check_resize_
|
|||
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
|
||||
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
|
||||
check_uniform_augment_cpp, check_convert_color, check_random_resize_crop, check_random_auto_contrast, \
|
||||
check_random_adjust_sharpness, \
|
||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
|
||||
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches, check_adjust_gamma
|
||||
|
@ -671,6 +672,32 @@ class Pad(ImageTensorOperation):
|
|||
return cde.PadOperation(self.padding, self.fill_value, DE_C_BORDER_TYPE[self.padding_mode])
|
||||
|
||||
|
||||
class RandomAdjustSharpness(ImageTensorOperation):
|
||||
"""
|
||||
Randomly adjust the sharpness of the input image with a given probability.
|
||||
|
||||
Args:
|
||||
degree (float): Sharpness adjustment degree, which must be non negative.
|
||||
Degree of 0.0 gives a blurred image, degree of 1.0 gives the original image,
|
||||
and degree of 2.0 increases the sharpness by a factor of 2.
|
||||
prob (float, optional): Probability of the image being sharpness adjusted, which
|
||||
must be in range of [0, 1] (default=0.5).
|
||||
|
||||
Examples:
|
||||
>>> transforms_list = [c_vision.Decode(), c_vision.RandomAdjustSharpness(2.0, 0.5)]
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns=["image"])
|
||||
"""
|
||||
|
||||
@check_random_adjust_sharpness
|
||||
def __init__(self, degree, prob=0.5):
|
||||
self.prob = prob
|
||||
self.degree = degree
|
||||
|
||||
def parse(self):
|
||||
return cde.RandomAdjustSharpnessOperation(self.degree, self.prob)
|
||||
|
||||
|
||||
class RandomAffine(ImageTensorOperation):
|
||||
"""
|
||||
Apply Random affine transformation to the input image.
|
||||
|
|
|
@ -296,6 +296,22 @@ def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
|
|||
check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
|
||||
|
||||
|
||||
def check_random_adjust_sharpness(method):
|
||||
"""Wrapper method to check the parameters of RandomAdjustSharpness."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[degree, prob], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(degree, (float, int), "degree")
|
||||
check_non_negative_float32(degree, "degree")
|
||||
type_check(prob, (float, int), "prob")
|
||||
check_value(prob, [0., 1.], "prob")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_random_resize_crop(method):
|
||||
"""A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
|
||||
|
||||
|
|
|
@ -1270,129 +1270,3 @@ TEST_F(MindDataTestPipeline, TestConvertColorFail) {
|
|||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomInvert) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvert.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_invert_op = vision::RandomInvert(0.5);
|
||||
|
||||
ds = ds->Map({random_invert_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomInvertInvalidProb) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvertInvalidProb.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_invert_op = vision::RandomInvert(1.5);
|
||||
|
||||
ds = ds->Map({random_invert_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrast) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrast.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(1.0, {0, 255}, 0.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidProb) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidProb.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(0.0, {}, 1.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidCutoff) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidCutoff.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(-2.0, {}, 0.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidIgnore) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidCutoff.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(1.0, {10, 256}, 0.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -396,3 +396,193 @@ TEST_F(MindDataTestPipeline, TestRandomEqualizeInvalidProb) {
|
|||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomInvert) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvert.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_invert_op = vision::RandomInvert(0.5);
|
||||
|
||||
ds = ds->Map({random_invert_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomInvertInvalidProb) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvertInvalidProb.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_invert_op = vision::RandomInvert(1.5);
|
||||
|
||||
ds = ds->Map({random_invert_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrast) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrast.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(1.0, {0, 255}, 0.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidProb) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidProb.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(0.0, {}, 1.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidCutoff) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidCutoff.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(-2.0, {}, 0.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAutoContrastInvalidIgnore) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAutoContrastInvalidCutoff.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_auto_contrast_op = vision::RandomAutoContrast(1.0, {10, 256}, 0.5);
|
||||
|
||||
ds = ds->Map({random_auto_contrast_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAdjustSharpness) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAdjustSharpness.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_adjust_sharpness_op = vision::RandomAdjustSharpness(2.0, 0.5);
|
||||
|
||||
ds = ds->Map({random_adjust_sharpness_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAdjustSharpnessInvalidProb) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAdjustSharpnessInvalidProb.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_adjust_sharpness_op = vision::RandomAdjustSharpness(2.0, 1.5);
|
||||
|
||||
ds = ds->Map({random_adjust_sharpness_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomAdjustSharpnessInvalidDegree) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomAdjustSharpnessInvalidProb.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto random_adjust_sharpness_op = vision::RandomAdjustSharpness(-2.0, 0.3);
|
||||
|
||||
ds = ds->Map({random_adjust_sharpness_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -1140,3 +1140,17 @@ TEST_F(MindDataTestExecute, TestRandomEqualizeEager) {
|
|||
Status rc = transform(image, &image);
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestRandomAdjustSharpnessEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomAdjustSharpnessEager.";
|
||||
// Read images
|
||||
auto image = ReadFileToTensor("data/dataset/apple.jpg");
|
||||
|
||||
// Transform params
|
||||
auto decode = vision::Decode();
|
||||
auto random_adjust_sharpness_op = vision::RandomAdjustSharpness(2.0, 0.6);
|
||||
|
||||
auto transform = Execute({decode, random_adjust_sharpness_op});
|
||||
Status rc = transform(image, &image);
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing RandomAdjustSharpness in DE
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, visualize_image, diff_mse
|
||||
|
||||
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
|
||||
data_dir = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def test_random_adjust_sharpness_pipeline(plot=False):
|
||||
"""
|
||||
Test RandomAdjustSharpness pipeline
|
||||
"""
|
||||
logger.info("Test RandomAdjustSharpness pipeline")
|
||||
|
||||
# Original Images
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
transforms_original = [c_vision.Decode(), c_vision.Resize(size=[224, 224])]
|
||||
ds_original = data_set.map(operations=transforms_original, input_columns="image")
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image.asnumpy()
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
image.asnumpy(),
|
||||
axis=0)
|
||||
|
||||
# Randomly Sharpness Adjusted Images
|
||||
data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
transform_random_adjust_sharpness = [c_vision.Decode(),
|
||||
c_vision.Resize(size=[224, 224]),
|
||||
c_vision.RandomAdjustSharpness(2.0, 0.6)]
|
||||
ds_random_adjust_sharpness = data_set1.map(operations=transform_random_adjust_sharpness, input_columns="image")
|
||||
ds_random_adjust_sharpness = ds_random_adjust_sharpness.batch(512)
|
||||
for idx, (image, _) in enumerate(ds_random_adjust_sharpness):
|
||||
if idx == 0:
|
||||
images_random_adjust_sharpness = image.asnumpy()
|
||||
else:
|
||||
images_random_adjust_sharpness = np.append(images_random_adjust_sharpness,
|
||||
image.asnumpy(),
|
||||
axis=0)
|
||||
if plot:
|
||||
visualize_list(images_original, images_random_adjust_sharpness)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_random_adjust_sharpness[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_random_adjust_sharpness_eager():
|
||||
"""
|
||||
Test RandomAdjustSharpness eager.
|
||||
"""
|
||||
img = np.fromfile(image_file, dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
img = c_vision.Decode()(img)
|
||||
img_sharped = c_vision.RandomSharpness((2.0, 2.0))(img)
|
||||
img_random_sharped = c_vision.RandomAdjustSharpness(2.0, 1.0)(img)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img_random_sharped), img_random_sharped.shape))
|
||||
|
||||
assert img_random_sharped.all() == img_sharped.all()
|
||||
|
||||
|
||||
def test_random_adjust_sharpness_comp(plot=False):
|
||||
"""
|
||||
Test RandomAdjustSharpness op compared with Sharpness op.
|
||||
"""
|
||||
random_adjust_sharpness_op = c_vision.RandomAdjustSharpness(degree=2.0, prob=1.0)
|
||||
sharpness_op = c_vision.RandomSharpness((2.0, 2.0))
|
||||
|
||||
dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
for item in dataset1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = item['image']
|
||||
dataset1.map(operations=random_adjust_sharpness_op, input_columns=['image'])
|
||||
dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
dataset2.map(operations=sharpness_op, input_columns=['image'])
|
||||
|
||||
for item1, item2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
image_random_sharpness = item1['image']
|
||||
image_sharpness = item2['image']
|
||||
|
||||
mse = diff_mse(image_sharpness, image_random_sharpness)
|
||||
assert mse == 0
|
||||
logger.info("mse: {}".format(mse))
|
||||
if plot:
|
||||
visualize_image(image, image_random_sharpness, mse, image_sharpness)
|
||||
|
||||
|
||||
def test_random_adjust_sharpness_invalid_prob():
|
||||
"""
|
||||
Test invalid prob. prob out of range.
|
||||
"""
|
||||
logger.info("test_random_adjust_sharpness_invalid_prob")
|
||||
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
try:
|
||||
random_adjust_sharpness_op = c_vision.RandomAdjustSharpness(2.0, 1.5)
|
||||
dataset = dataset.map(operations=random_adjust_sharpness_op, input_columns=['image'])
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
|
||||
|
||||
|
||||
def test_random_adjust_sharpness_invalid_degree():
|
||||
"""
|
||||
Test invalid prob. prob out of range.
|
||||
"""
|
||||
logger.info("test_random_adjust_sharpness_invalid_prob")
|
||||
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
try:
|
||||
random_adjust_sharpness_op = c_vision.RandomAdjustSharpness(-1.0, 1.5)
|
||||
dataset = dataset.map(operations=random_adjust_sharpness_op, input_columns=['image'])
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "interval" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_adjust_sharpness_pipeline(plot=True)
|
||||
test_random_adjust_sharpness_eager()
|
||||
test_random_adjust_sharpness_comp(plot=True)
|
||||
test_random_adjust_sharpness_invalid_prob()
|
||||
test_random_adjust_sharpness_invalid_degree()
|
|
@ -93,13 +93,13 @@ def test_random_auto_contrast_comp(plot=False):
|
|||
auto_contrast_op = c_vision.AutoContrast()
|
||||
|
||||
dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
for item in dataset1.create_dict_iterator(output_numpy=True):
|
||||
for item in dataset1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = item['image']
|
||||
dataset1.map(operations=random_auto_contrast_op, input_columns=['image'])
|
||||
dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
dataset2.map(operations=auto_contrast_op, input_columns=['image'])
|
||||
for item1, item2 in zip(dataset1.create_dict_iterator(output_numpy=True),
|
||||
dataset2.create_dict_iterator(output_numpy=True)):
|
||||
for item1, item2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
image_random_auto_contrast = item1['image']
|
||||
image_auto_contrast = item2['image']
|
||||
|
||||
|
|
|
@ -91,13 +91,13 @@ def test_random_invert_comp(plot=False):
|
|||
invert_op = Invert()
|
||||
|
||||
dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
for item in dataset1.create_dict_iterator(output_numpy=True):
|
||||
for item in dataset1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = item['image']
|
||||
dataset1.map(operations=random_invert_op, input_columns=['image'])
|
||||
dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
dataset2.map(operations=invert_op, input_columns=['image'])
|
||||
for item1, item2 in zip(dataset1.create_dict_iterator(output_numpy=True),
|
||||
dataset2.create_dict_iterator(output_numpy=True)):
|
||||
for item1, item2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
|
||||
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
|
||||
image_random_inverted = item1['image']
|
||||
image_inverted = item2['image']
|
||||
|
||||
|
|
Loading…
Reference in New Issue