forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I40GZC] add new data ops RandomInvert
This commit is contained in:
parent
4141404213
commit
b977d8723b
|
@ -44,6 +44,7 @@
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
|
||||||
|
@ -373,6 +374,17 @@ PYBIND_REGISTER(RandomHorizontalFlipWithBBoxOperation, 1, ([](const py::module *
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(
|
||||||
|
RandomInvertOperation, 1, ([](const py::module *m) {
|
||||||
|
(void)py::class_<vision::RandomInvertOperation, TensorOperation, std::shared_ptr<vision::RandomInvertOperation>>(
|
||||||
|
*m, "RandomInvertOperation")
|
||||||
|
.def(py::init([](float prob) {
|
||||||
|
auto random_invert = std::make_shared<vision::RandomInvertOperation>(prob);
|
||||||
|
THROW_IF_ERROR(random_invert->ValidateParams());
|
||||||
|
return random_invert;
|
||||||
|
}));
|
||||||
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(RandomPosterizeOperation, 1, ([](const py::module *m) {
|
PYBIND_REGISTER(RandomPosterizeOperation, 1, ([](const py::module *m) {
|
||||||
(void)py::class_<vision::RandomPosterizeOperation, TensorOperation,
|
(void)py::class_<vision::RandomPosterizeOperation, TensorOperation,
|
||||||
std::shared_ptr<vision::RandomPosterizeOperation>>(*m, "RandomPosterizeOperation")
|
std::shared_ptr<vision::RandomPosterizeOperation>>(*m, "RandomPosterizeOperation")
|
||||||
|
|
|
@ -48,6 +48,7 @@
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_crop_with_bbox_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_horizontal_flip_with_bbox_ir.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_posterize_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_ir.h"
|
||||||
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
|
#include "minddata/dataset/kernels/ir/vision/random_resized_crop_with_bbox_ir.h"
|
||||||
|
@ -608,6 +609,18 @@ std::shared_ptr<TensorOperation> RandomHorizontalFlipWithBBox::Parse() {
|
||||||
return std::make_shared<RandomHorizontalFlipWithBBoxOperation>(data_->probability_);
|
return std::make_shared<RandomHorizontalFlipWithBBoxOperation>(data_->probability_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RandomInvert Operation.
|
||||||
|
struct RandomInvert::Data {
|
||||||
|
explicit Data(float prob) : probability_(prob) {}
|
||||||
|
float probability_;
|
||||||
|
};
|
||||||
|
|
||||||
|
RandomInvert::RandomInvert(float prob) : data_(std::make_shared<Data>(prob)) {}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOperation> RandomInvert::Parse() {
|
||||||
|
return std::make_shared<RandomInvertOperation>(data_->probability_);
|
||||||
|
}
|
||||||
|
|
||||||
// RandomPosterize Transform Operation.
|
// RandomPosterize Transform Operation.
|
||||||
struct RandomPosterize::Data {
|
struct RandomPosterize::Data {
|
||||||
explicit Data(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
|
explicit Data(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
|
||||||
|
|
|
@ -535,6 +535,27 @@ class RandomHorizontalFlipWithBBox final : public TensorTransform {
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief Randomly invert the input image with a given probability.
|
||||||
|
class RandomInvert final : public TensorTransform {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
/// \param[in] prob A float representing the probability of the image being inverted, which
|
||||||
|
/// must be in range of [0, 1] (default=0.5).
|
||||||
|
explicit RandomInvert(float prob = 0.5);
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~RandomInvert() = default;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||||
|
/// \return Shared pointer to TensorOperation object.
|
||||||
|
std::shared_ptr<TensorOperation> Parse() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Data;
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
|
};
|
||||||
|
|
||||||
/// \brief Reduce the number of bits for each color channel randomly.
|
/// \brief Reduce the number of bits for each color channel randomly.
|
||||||
class RandomPosterize final : public TensorTransform {
|
class RandomPosterize final : public TensorTransform {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -37,6 +37,7 @@ add_library(kernels-image OBJECT
|
||||||
random_crop_with_bbox_op.cc
|
random_crop_with_bbox_op.cc
|
||||||
random_horizontal_flip_op.cc
|
random_horizontal_flip_op.cc
|
||||||
random_horizontal_flip_with_bbox_op.cc
|
random_horizontal_flip_with_bbox_op.cc
|
||||||
|
random_invert_op.cc
|
||||||
bounding_box_augment_op.cc
|
bounding_box_augment_op.cc
|
||||||
random_posterize_op.cc
|
random_posterize_op.cc
|
||||||
random_resize_op.cc
|
random_resize_op.cc
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/image/random_invert_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
const float RandomInvertOp::kDefProbability = 0.5;
|
||||||
|
|
||||||
|
Status RandomInvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
IO_CHECK(input, output);
|
||||||
|
if (distribution_(rnd_)) {
|
||||||
|
return InvertOp::Compute(input, output);
|
||||||
|
}
|
||||||
|
*output = input;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_INVERT_OP_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_INVERT_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class RandomInvertOp : public InvertOp {
|
||||||
|
public:
|
||||||
|
static const float kDefProbability;
|
||||||
|
|
||||||
|
explicit RandomInvertOp(float prob = kDefProbability) : distribution_(prob) {
|
||||||
|
is_deterministic_ = false;
|
||||||
|
rnd_.seed(GetSeed());
|
||||||
|
}
|
||||||
|
|
||||||
|
~RandomInvertOp() override = default;
|
||||||
|
|
||||||
|
// Provide stream operator for displaying it
|
||||||
|
friend std::ostream &operator<<(std::ostream &out, const RandomInvertOp &so) {
|
||||||
|
so.Print(out);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kRandomInvertOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mt19937 rnd_;
|
||||||
|
std::bernoulli_distribution distribution_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_INVERT_OP_H_
|
|
@ -29,6 +29,7 @@ set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
||||||
random_crop_with_bbox_ir.cc
|
random_crop_with_bbox_ir.cc
|
||||||
random_horizontal_flip_ir.cc
|
random_horizontal_flip_ir.cc
|
||||||
random_horizontal_flip_with_bbox_ir.cc
|
random_horizontal_flip_with_bbox_ir.cc
|
||||||
|
random_invert_ir.cc
|
||||||
random_posterize_ir.cc
|
random_posterize_ir.cc
|
||||||
random_resized_crop_ir.cc
|
random_resized_crop_ir.cc
|
||||||
random_resized_crop_with_bbox_ir.cc
|
random_resized_crop_with_bbox_ir.cc
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/ir/vision/random_invert_ir.h"
|
||||||
|
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
#include "minddata/dataset/kernels/image/random_invert_op.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "minddata/dataset/kernels/ir/validators.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace vision {
|
||||||
|
#ifndef ENABLE_ANDROID
|
||||||
|
// RandomInvertOperation
|
||||||
|
RandomInvertOperation::RandomInvertOperation(float prob) : TensorOperation(true), probability_(prob) {}
|
||||||
|
|
||||||
|
RandomInvertOperation::~RandomInvertOperation() = default;
|
||||||
|
|
||||||
|
std::string RandomInvertOperation::Name() const { return kRandomInvertOperation; }
|
||||||
|
|
||||||
|
Status RandomInvertOperation::ValidateParams() {
|
||||||
|
RETURN_IF_NOT_OK(ValidateProbability("RandomInvert", probability_));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> RandomInvertOperation::Build() {
|
||||||
|
std::shared_ptr<RandomInvertOp> tensor_op = std::make_shared<RandomInvertOp>(probability_);
|
||||||
|
return tensor_op;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomInvertOperation::to_json(nlohmann::json *out_json) {
|
||||||
|
(*out_json)["prob"] = probability_;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomInvertOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("prob") != op_params.end(), "Failed to find prob");
|
||||||
|
float prob = op_params["prob"];
|
||||||
|
*operation = std::make_shared<vision::RandomInvertOperation>(prob);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_INVERT_IR_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_INVERT_IR_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "include/api/status.h"
|
||||||
|
#include "minddata/dataset/include/dataset/constants.h"
|
||||||
|
#include "minddata/dataset/include/dataset/transforms.h"
|
||||||
|
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
constexpr char kRandomInvertOperation[] = "RandomInvert";
|
||||||
|
|
||||||
|
class RandomInvertOperation : public TensorOperation {
|
||||||
|
public:
|
||||||
|
explicit RandomInvertOperation(float prob);
|
||||||
|
|
||||||
|
~RandomInvertOperation();
|
||||||
|
|
||||||
|
std::shared_ptr<TensorOp> Build() override;
|
||||||
|
|
||||||
|
Status ValidateParams() override;
|
||||||
|
|
||||||
|
std::string Name() const override;
|
||||||
|
|
||||||
|
Status to_json(nlohmann::json *out_json) override;
|
||||||
|
|
||||||
|
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||||
|
|
||||||
|
private:
|
||||||
|
float probability_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_RANDOM_INVERT_IR_H_
|
|
@ -89,6 +89,7 @@ constexpr char kRandomCropOp[] = "RandomCropOp";
|
||||||
constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp";
|
constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp";
|
||||||
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
|
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
|
||||||
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
|
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
|
||||||
|
constexpr char kRandomInvertOp[] = "RandomInvertOp";
|
||||||
constexpr char kRandomResizeOp[] = "RandomResizeOp";
|
constexpr char kRandomResizeOp[] = "RandomResizeOp";
|
||||||
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
|
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
|
||||||
constexpr char kRandomRotationOp[] = "RandomRotationOp";
|
constexpr char kRandomRotationOp[] = "RandomRotationOp";
|
||||||
|
|
|
@ -1092,6 +1092,27 @@ class RandomHorizontalFlipWithBBox(ImageTensorOperation):
|
||||||
return cde.RandomHorizontalFlipWithBBoxOperation(self.prob)
|
return cde.RandomHorizontalFlipWithBBoxOperation(self.prob)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomInvert(ImageTensorOperation):
|
||||||
|
"""
|
||||||
|
Randomly invert the colors of image with a given probability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prob (float, optional): Probability of the image being inverted, which must be in range of [0, 1] (default=0.5).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> transforms_list = [c_vision.Decode(), c_vision.RandomInvert(0.5)]
|
||||||
|
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||||
|
... input_columns=["image"])
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_prob
|
||||||
|
def __init__(self, prob=0.5):
|
||||||
|
self.prob = prob
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
return cde.RandomInvertOperation(self.prob)
|
||||||
|
|
||||||
|
|
||||||
class RandomPosterize(ImageTensorOperation):
|
class RandomPosterize(ImageTensorOperation):
|
||||||
"""
|
"""
|
||||||
Reduce the number of bits for each color channel to posterize the input image randomly with a given probability.
|
Reduce the number of bits for each color channel to posterize the input image randomly with a given probability.
|
||||||
|
|
|
@ -1270,3 +1270,49 @@ TEST_F(MindDataTestPipeline, TestConvertColorFail) {
|
||||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
EXPECT_EQ(iter, nullptr);
|
EXPECT_EQ(iter, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestRandomInvert) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvert.";
|
||||||
|
|
||||||
|
std::string MindDataPath = "data/dataset";
|
||||||
|
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
auto random_invert_op = vision::RandomInvert(0.5);
|
||||||
|
|
||||||
|
ds = ds->Map({random_invert_op});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
i++;
|
||||||
|
auto image = row["image"];
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(i, 2);
|
||||||
|
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestRandomInvertInvalidProb) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomInvertInvalidProb.";
|
||||||
|
|
||||||
|
std::string MindDataPath = "data/dataset";
|
||||||
|
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
auto random_invert_op = vision::RandomInvert(1.5);
|
||||||
|
|
||||||
|
ds = ds->Map({random_invert_op});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_EQ(iter, nullptr);
|
||||||
|
}
|
||||||
|
|
|
@ -1057,6 +1057,7 @@ TEST_F(MindDataTestExecute, TestFadeWithInvalidArg) {
|
||||||
Status s04 = Transform04(input_04, &input_04);
|
Status s04 = Transform04(input_04, &input_04);
|
||||||
EXPECT_FALSE(s04.IsOk());
|
EXPECT_FALSE(s04.IsOk());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestExecute, TestVolDefalutValue) {
|
TEST_F(MindDataTestExecute, TestVolDefalutValue) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolDefalutValue.";
|
MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolDefalutValue.";
|
||||||
std::shared_ptr<Tensor> input_tensor_;
|
std::shared_ptr<Tensor> input_tensor_;
|
||||||
|
@ -1097,3 +1098,17 @@ TEST_F(MindDataTestExecute, TestMagphaseEager) {
|
||||||
Status rc = transform({input_tensor}, &output_tensor);
|
Status rc = transform({input_tensor}, &output_tensor);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestExecute, TestRandomInvertEager) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRandomInvertEager.";
|
||||||
|
// Read images
|
||||||
|
auto image = ReadFileToTensor("data/dataset/apple.jpg");
|
||||||
|
|
||||||
|
// Transform params
|
||||||
|
auto decode = vision::Decode();
|
||||||
|
auto random_invert_op = vision::RandomInvert(0.6);
|
||||||
|
|
||||||
|
auto transform = Execute({decode, random_invert_op});
|
||||||
|
Status rc = transform(image, &image);
|
||||||
|
EXPECT_EQ(rc, Status::OK());
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing RandomInvert in DE
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
from mindspore.dataset.vision.c_transforms import Decode, Resize, RandomInvert, Invert
|
||||||
|
from mindspore import log as logger
|
||||||
|
from util import visualize_list, visualize_image, diff_mse
|
||||||
|
|
||||||
|
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
|
||||||
|
data_dir = "../data/dataset/testImageNetData/train/"
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_invert_pipeline(plot=False):
|
||||||
|
"""
|
||||||
|
Test RandomInvert pipeline
|
||||||
|
"""
|
||||||
|
logger.info("Test RandomInvert pipeline")
|
||||||
|
|
||||||
|
# Original Images
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||||
|
transforms_original = [Decode(), Resize(size=[224, 224])]
|
||||||
|
ds_original = data_set.map(operations=transforms_original, input_columns="image")
|
||||||
|
ds_original = ds_original.batch(512)
|
||||||
|
|
||||||
|
for idx, (image, _) in enumerate(ds_original):
|
||||||
|
if idx == 0:
|
||||||
|
images_original = image.asnumpy()
|
||||||
|
else:
|
||||||
|
images_original = np.append(images_original,
|
||||||
|
image.asnumpy(),
|
||||||
|
axis=0)
|
||||||
|
|
||||||
|
# Randomly Inverted Images
|
||||||
|
data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||||
|
transform_random_invert = [Decode(), Resize(size=[224, 224]), RandomInvert(0.6)]
|
||||||
|
ds_random_invert = data_set1.map(operations=transform_random_invert, input_columns="image")
|
||||||
|
ds_random_invert = ds_random_invert.batch(512)
|
||||||
|
for idx, (image, _) in enumerate(ds_random_invert):
|
||||||
|
if idx == 0:
|
||||||
|
images_random_invert = image.asnumpy()
|
||||||
|
else:
|
||||||
|
images_random_invert = np.append(images_random_invert,
|
||||||
|
image.asnumpy(),
|
||||||
|
axis=0)
|
||||||
|
if plot:
|
||||||
|
visualize_list(images_original, images_random_invert)
|
||||||
|
|
||||||
|
num_samples = images_original.shape[0]
|
||||||
|
mse = np.zeros(num_samples)
|
||||||
|
for i in range(num_samples):
|
||||||
|
mse[i] = diff_mse(images_random_invert[i], images_original[i])
|
||||||
|
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_invert_eager():
|
||||||
|
"""
|
||||||
|
Test RandomInvert eager.
|
||||||
|
"""
|
||||||
|
img = np.fromfile(image_file, dtype=np.uint8)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||||
|
|
||||||
|
img = Decode()(img)
|
||||||
|
img_inverted = Invert()(img)
|
||||||
|
img_random_inverted = RandomInvert(1.0)(img)
|
||||||
|
logger.info("Image.type: {}, Image.shape: {}".format(type(img_random_inverted), img_random_inverted.shape))
|
||||||
|
|
||||||
|
assert img_random_inverted.all() == img_inverted.all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_invert_comp(plot=False):
|
||||||
|
"""
|
||||||
|
Test RandomInvert op compared with Invert op.
|
||||||
|
"""
|
||||||
|
random_invert_op = RandomInvert(prob=1.0)
|
||||||
|
invert_op = Invert()
|
||||||
|
|
||||||
|
dataset1 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||||
|
for item in dataset1.create_dict_iterator(output_numpy=True):
|
||||||
|
image = item['image']
|
||||||
|
dataset1.map(operations=random_invert_op, input_columns=['image'])
|
||||||
|
dataset2 = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||||
|
dataset2.map(operations=invert_op, input_columns=['image'])
|
||||||
|
for item1, item2 in zip(dataset1.create_dict_iterator(output_numpy=True),
|
||||||
|
dataset2.create_dict_iterator(output_numpy=True)):
|
||||||
|
image_random_inverted = item1['image']
|
||||||
|
image_inverted = item2['image']
|
||||||
|
|
||||||
|
mse = diff_mse(image_inverted, image_random_inverted)
|
||||||
|
assert mse == 0
|
||||||
|
logger.info("mse: {}".format(mse))
|
||||||
|
if plot:
|
||||||
|
visualize_image(image, image_random_inverted, mse, image_inverted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_invert_invalid_prob():
|
||||||
|
"""
|
||||||
|
Test invalid prob. prob out of range.
|
||||||
|
"""
|
||||||
|
logger.info("test_random_invert_invalid_prob")
|
||||||
|
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||||
|
try:
|
||||||
|
random_invert_op = RandomInvert(1.5)
|
||||||
|
dataset = dataset.map(operations=random_invert_op, input_columns=['image'])
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||||
|
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_random_invert_pipeline(plot=True)
|
||||||
|
test_random_invert_eager()
|
||||||
|
test_random_invert_comp(plot=True)
|
||||||
|
test_random_invert_invalid_prob()
|
Loading…
Reference in New Issue