forked from mindspore-Ecosystem/mindspore
!24547 [MS][crowdfunding]New operator implementation, AutoAugment
Merge pull request !24547 from yangwm/autoaugment
This commit is contained in:
commit
3c39afad11
|
@ -104,6 +104,14 @@ PYBIND_REGISTER(DataType, 0, ([](const py::module *m) {
|
|||
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(AutoAugmentPolicy, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<AutoAugmentPolicy>(*m, "AutoAugmentPolicy", py::arithmetic())
|
||||
.value("DE_AUTO_AUGMENT_POLICY_IMAGENET", AutoAugmentPolicy::kImageNet)
|
||||
.value("DE_AUTO_AUGMENT_POLICY_CIFAR10", AutoAugmentPolicy::kCifar10)
|
||||
.value("DE_AUTO_AUGMENT_POLICY_SVHN", AutoAugmentPolicy::kSVHN)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(BorderType, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<BorderType>(*m, "BorderType", py::arithmetic())
|
||||
.value("DE_BORDER_CONSTANT", BorderType::kConstant)
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
|
||||
|
@ -73,7 +74,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
AdjustGammaOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::AdjustGammaOperation, TensorOperation, std::shared_ptr<vision::AdjustGammaOperation>>(
|
||||
|
@ -85,6 +85,18 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
AutoAugmentOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::AutoAugmentOperation, TensorOperation, std::shared_ptr<vision::AutoAugmentOperation>>(
|
||||
*m, "AutoAugmentOperation")
|
||||
.def(
|
||||
py::init([](AutoAugmentPolicy policy, InterpolationMode interpolation, const std::vector<uint8_t> &fill_value) {
|
||||
auto auto_augment = std::make_shared<vision::AutoAugmentOperation>(policy, interpolation, fill_value);
|
||||
THROW_IF_ERROR(auto_augment->ValidateParams());
|
||||
return auto_augment;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
AutoContrastOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::AutoContrastOperation, TensorOperation, std::shared_ptr<vision::AutoContrastOperation>>(
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
|
||||
|
@ -138,6 +139,23 @@ std::shared_ptr<TensorOperation> AdjustGamma::Parse() {
|
|||
return std::make_shared<AdjustGammaOperation>(data_->gamma_, data_->gain_);
|
||||
}
|
||||
|
||||
// AutoAugment Transform Operation.
|
||||
struct AutoAugment::Data {
|
||||
Data(AutoAugmentPolicy policy, InterpolationMode interpolation, const std::vector<uint8_t> &fill_value)
|
||||
: policy_(policy), interpolation_(interpolation), fill_value_(fill_value) {}
|
||||
AutoAugmentPolicy policy_;
|
||||
InterpolationMode interpolation_;
|
||||
std::vector<uint8_t> fill_value_;
|
||||
};
|
||||
|
||||
AutoAugment::AutoAugment(AutoAugmentPolicy policy, InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value)
|
||||
: data_(std::make_shared<Data>(policy, interpolation, fill_value)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> AutoAugment::Parse() {
|
||||
return std::make_shared<AutoAugmentOperation>(data_->policy_, data_->interpolation_, data_->fill_value_);
|
||||
}
|
||||
|
||||
// AutoContrast Transform Operation.
|
||||
struct AutoContrast::Data {
|
||||
Data(float cutoff, const std::vector<uint32_t> &ignore) : cutoff_(cutoff), ignore_(ignore) {}
|
||||
|
|
|
@ -38,6 +38,13 @@ enum class Interpolation {
|
|||
kQuadratic = 1 ///< Use quadratic for delay-line interpolation.
|
||||
};
|
||||
|
||||
/// \brief The dataset auto augment policy in AutoAugment
|
||||
enum class AutoAugmentPolicy {
|
||||
kImageNet = 0, ///< AutoAugment policy learned on the ImageNet dataset.
|
||||
kCifar10 = 1, ///< AutoAugment policy learned on the Cifar10 dataset.
|
||||
kSVHN = 2 ///< AutoAugment policy learned on the SVHN dataset.
|
||||
};
|
||||
|
||||
/// \brief The color conversion code
|
||||
enum class ConvertMode {
|
||||
COLOR_BGR2BGRA = 0, ///< Add alpha channel to BGR image.
|
||||
|
|
|
@ -70,6 +70,46 @@ class AdjustGamma final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply AutoAugment data augmentation method.
|
||||
class AutoAugment final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] policy An enum for the data auto augmentation policy (default=AutoAugmentPolicy::kImageNet).
|
||||
/// - AutoAugmentPolicy::kIMAGENET, AutoAugment policy learned on the ImageNet dataset.
|
||||
/// - AutoAugmentPolicy::kCIFAR10, AutoAugment policy learned on the Cifar10 dataset.
|
||||
/// - AutoAugmentPolicy::kSVHN, AutoAugment policy learned on the SVHN dataset.
|
||||
/// \param[in] interpolation An enum for the mode of interpolation (default=InterpolationMode::kNearestNeighbour).
|
||||
/// - InterpolationMode::kLinear, Interpolation method is blinear interpolation.
|
||||
/// - InterpolationMode::kNearestNeighbour, Interpolation method is nearest-neighbor interpolation.
|
||||
/// - InterpolationMode::kCubic, Interpolation method is bicubic interpolation.
|
||||
/// \param[in] fill_value A vector representing the pixel intensity of the borders (default={0, 0, 0}).
|
||||
/// \par Example
|
||||
/// \code
|
||||
/// /* Define operations */
|
||||
/// auto decode_op = vision::Decode();
|
||||
/// auto auto_augment_op = vision::AutoAugment(AutoAugmentPolicy::kImageNet,
|
||||
/// InterpolationMode::kNearestNeighbour, {0, 0, 0});
|
||||
/// /* dataset is an instance of Dataset object */
|
||||
/// dataset = dataset->Map({decode_op, auto_augment_op}, // operations
|
||||
/// {"image"}); // input columns
|
||||
/// \endcode
|
||||
AutoAugment(AutoAugmentPolicy policy = AutoAugmentPolicy::kImageNet,
|
||||
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
||||
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||
|
||||
/// \brief Destructor.
|
||||
~AutoAugment() = default;
|
||||
|
||||
protected:
|
||||
/// \brief The function to convert a TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply automatic contrast on the input image.
|
||||
class AutoContrast final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -8,6 +8,7 @@ endif()
|
|||
add_library(kernels-image OBJECT
|
||||
adjust_gamma_op.cc
|
||||
affine_op.cc
|
||||
auto_augment_op.cc
|
||||
auto_contrast_op.cc
|
||||
bounding_box.cc
|
||||
center_crop_op.cc
|
||||
|
|
|
@ -0,0 +1,219 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/auto_augment_op.h"
|
||||
|
||||
#include "minddata/dataset/kernels/image/affine_op.h"
|
||||
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
|
||||
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||
#include "minddata/dataset/kernels/image/posterize_op.h"
|
||||
#include "minddata/dataset/kernels/image/sharpness_op.h"
|
||||
#include "minddata/dataset/kernels/image/solarize_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
AutoAugmentOp::AutoAugmentOp(AutoAugmentPolicy policy, InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value)
|
||||
: policy_(policy), interpolation_(interpolation), fill_value_(fill_value) {
|
||||
rnd_.seed(GetSeed());
|
||||
transforms_ = GetTransforms(policy);
|
||||
}
|
||||
|
||||
Transforms AutoAugmentOp::GetTransforms(AutoAugmentPolicy policy) {
|
||||
if (policy == AutoAugmentPolicy::kImageNet) {
|
||||
return {{{"Posterize", 0.4, 8}, {"Rotate", 0.6, 9}}, {{"Solarize", 0.6, 5}, {"AutoContrast", 0.6, -1}},
|
||||
{{"Equalize", 0.8, -1}, {"Equalize", 0.6, -1}}, {{"Posterize", 0.6, 7}, {"Posterize", 0.6, 6}},
|
||||
{{"Equalize", 0.4, -1}, {"Solarize", 0.2, 4}}, {{"Equalize", 0.4, -1}, {"Rotate", 0.8, 8}},
|
||||
{{"Solarize", 0.6, 3}, {"Equalize", 0.6, -1}}, {{"Posterize", 0.8, 5}, {"Equalize", 1.0, -1}},
|
||||
{{"Rotate", 0.2, 3}, {"Solarize", 0.6, 8}}, {{"Equalize", 0.6, -1}, {"Posterize", 0.4, 6}},
|
||||
{{"Rotate", 0.8, 8}, {"Color", 0.4, 0}}, {{"Rotate", 0.4, 9}, {"Equalize", 0.6, -1}},
|
||||
{{"Equalize", 0.0, -1}, {"Equalize", 0.8, -1}}, {{"Invert", 0.6, -1}, {"Equalize", 1.0, -1}},
|
||||
{{"Color", 0.6, 4}, {"Contrast", 1.0, 8}}, {{"Rotate", 0.8, 8}, {"Color", 1.0, 2}},
|
||||
{{"Color", 0.8, 8}, {"Solarize", 0.8, 7}}, {{"Sharpness", 0.4, 7}, {"Invert", 0.6, -1}},
|
||||
{{"ShearX", 0.6, 5}, {"Equalize", 1.0, -1}}, {{"Color", 0.4, 0}, {"Equalize", 0.6, -1}},
|
||||
{{"Equalize", 0.4, -1}, {"Solarize", 0.2, 4}}, {{"Solarize", 0.6, 5}, {"AutoContrast", 0.6, -1}},
|
||||
{{"Invert", 0.6, -1}, {"Equalize", 1.0, -1}}, {{"Color", 0.6, 4}, {"Contrast", 1.0, 8}},
|
||||
{{"Equalize", 0.8, -1}, {"Equalize", 0.6, -1}}};
|
||||
} else if (policy == AutoAugmentPolicy::kCifar10) {
|
||||
return {{{"Invert", 0.1, -1}, {"Contrast", 0.2, 6}}, {{"Rotate", 0.7, 2}, {"TranslateX", 0.3, 9}},
|
||||
{{"Sharpness", 0.8, 1}, {"Sharpness", 0.9, 3}}, {{"ShearY", 0.5, 8}, {"TranslateY", 0.7, 9}},
|
||||
{{"AutoContrast", 0.5, -1}, {"Equalize", 0.9, -1}}, {{"ShearY", 0.2, 7}, {"Posterize", 0.3, 7}},
|
||||
{{"Color", 0.4, 3}, {"Brightness", 0.6, 7}}, {{"Sharpness", 0.3, 9}, {"Brightness", 0.7, 9}},
|
||||
{{"Equalize", 0.6, -1}, {"Equalize", 0.5, -1}}, {{"Contrast", 0.6, 7}, {"Sharpness", 0.6, 5}},
|
||||
{{"Color", 0.7, 7}, {"TranslateX", 0.5, 8}}, {{"Equalize", 0.3, -1}, {"AutoContrast", 0.4, -1}},
|
||||
{{"TranslateY", 0.4, 3}, {"Sharpness", 0.2, 6}}, {{"Brightness", 0.9, 6}, {"Color", 0.2, 8}},
|
||||
{{"Solarize", 0.5, 2}, {"Invert", 0.0, -1}}, {{"Equalize", 0.2, -1}, {"AutoContrast", 0.6, -1}},
|
||||
{{"Equalize", 0.2, -1}, {"Equalize", 0.6, -1}}, {{"Color", 0.9, 9}, {"Equalize", 0.6, -1}},
|
||||
{{"AutoContrast", 0.8, -1}, {"Solarize", 0.2, 8}}, {{"Brightness", 0.1, 3}, {"Color", 0.7, 0}},
|
||||
{{"Solarize", 0.4, 5}, {"AutoContrast", 0.9, -1}}, {{"TranslateY", 0.9, 9}, {"TranslateY", 0.7, 9}},
|
||||
{{"AutoContrast", 0.9, -1}, {"Solarize", 0.8, 3}}, {{"Equalize", 0.8, -1}, {"Invert", 0.1, -1}},
|
||||
{{"TranslateY", 0.7, 9}, {"AutoContrast", 0.9, -1}}};
|
||||
} else {
|
||||
return {{{"ShearX", 0.9, 4}, {"Invert", 0.2, -1}}, {{"ShearY", 0.9, 8}, {"Invert", 0.7, -1}},
|
||||
{{"Equalize", 0.6, -1}, {"Solarize", 0.6, 6}}, {{"Invert", 0.9, -1}, {"Equalize", 0.6, -1}},
|
||||
{{"Equalize", 0.6, -1}, {"Rotate", 0.9, 3}}, {{"ShearX", 0.9, 4}, {"AutoContrast", 0.8, -1}},
|
||||
{{"ShearY", 0.9, 8}, {"Invert", 0.4, -1}}, {{"ShearY", 0.9, 5}, {"Solarize", 0.2, 6}},
|
||||
{{"Invert", 0.9, -1}, {"AutoContrast", 0.8, -1}}, {{"Equalize", 0.6, -1}, {"Rotate", 0.9, 3}},
|
||||
{{"ShearX", 0.9, 4}, {"Solarize", 0.3, 3}}, {{"ShearY", 0.8, 8}, {"Invert", 0.7, -1}},
|
||||
{{"Equalize", 0.9, -1}, {"TranslateY", 0.6, 6}}, {{"Invert", 0.9, -1}, {"Equalize", 0.6, -1}},
|
||||
{{"Contrast", 0.3, 3}, {"Rotate", 0.8, 4}}, {{"Invert", 0.8, -1}, {"TranslateY", 0.0, 2}},
|
||||
{{"ShearY", 0.7, 6}, {"Solarize", 0.4, 8}}, {{"Invert", 0.6, -1}, {"Rotate", 0.8, 4}},
|
||||
{{"ShearY", 0.3, 7}, {"TranslateX", 0.9, 3}}, {{"ShearX", 0.1, 6}, {"Invert", 0.6, -1}},
|
||||
{{"Solarize", 0.7, 2}, {"TranslateY", 0.6, 7}}, {{"ShearY", 0.8, 4}, {"Invert", 0.8, -1}},
|
||||
{{"ShearX", 0.7, 9}, {"TranslateY", 0.8, 3}}, {{"ShearY", 0.8, 5}, {"AutoContrast", 0.7, -1}},
|
||||
{{"ShearX", 0.7, 2}, {"Invert", 0.1, -1}}};
|
||||
}
|
||||
}
|
||||
|
||||
Status AutoAugmentOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
if (input->Rank() != DEFAULT_IMAGE_RANK) {
|
||||
RETURN_STATUS_UNEXPECTED("AutoAugment: input tensor is not in shape of <H,W,C>, but got rank: " +
|
||||
std::to_string(input->Rank()));
|
||||
}
|
||||
int num_channels = input->shape()[2];
|
||||
if (num_channels != DEFAULT_IMAGE_CHANNELS) {
|
||||
RETURN_STATUS_UNEXPECTED("AutoAugment: channel of input image should be 3, but got: " +
|
||||
std::to_string(num_channels));
|
||||
}
|
||||
|
||||
int transform_id;
|
||||
std::vector<float> *probs = new std::vector<float>{0, 0};
|
||||
std::vector<int32_t> *signs = new std::vector<int32_t>{0, 0};
|
||||
GetParams(transforms_.size(), &transform_id, probs, signs);
|
||||
|
||||
std::vector<dsize_t> image_size = {input->shape()[0], input->shape()[1]};
|
||||
std::shared_ptr<Tensor> img = input;
|
||||
|
||||
const int num_augments = 2;
|
||||
for (auto i = 0; i < num_augments; i++) {
|
||||
std::string op_name = std::get<0>(transforms_[transform_id][i]);
|
||||
float p = std::get<1>(transforms_[transform_id][i]);
|
||||
int32_t magnitude_id = std::get<2>(transforms_[transform_id][i]);
|
||||
if ((*probs)[i] <= p) {
|
||||
Space space = GetSpace(10, image_size);
|
||||
std::vector<float> magnitudes = std::get<0>(space[op_name]);
|
||||
bool sign = std::get<1>(space[op_name]);
|
||||
float magnitude = 0.0;
|
||||
if (magnitudes.size() != 1 && magnitude_id != -1) {
|
||||
magnitude = magnitudes[magnitude_id];
|
||||
}
|
||||
if (sign && (*signs)[i] == 0) {
|
||||
magnitude *= -1.0;
|
||||
}
|
||||
RETURN_IF_NOT_OK(ApplyAugment(img, &img, op_name, magnitude));
|
||||
}
|
||||
}
|
||||
*output = img;
|
||||
delete probs;
|
||||
delete signs;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void AutoAugmentOp::GetParams(int transform_num, int *transform_id, std::vector<float> *probs,
|
||||
std::vector<int32_t> *signs) {
|
||||
std::uniform_int_distribution<int32_t> id_dist(0, transform_num - 1);
|
||||
*transform_id = id_dist(rnd_);
|
||||
std::uniform_real_distribution<float> prob_dist(0, 1);
|
||||
|
||||
(*probs)[0] = prob_dist(rnd_);
|
||||
(*probs)[1] = prob_dist(rnd_);
|
||||
|
||||
std::uniform_int_distribution<int32_t> sign_dist(0, 1);
|
||||
|
||||
(*signs)[0] = sign_dist(rnd_);
|
||||
(*signs)[1] = sign_dist(rnd_);
|
||||
}
|
||||
|
||||
std::vector<float> Linspace(float start, float end, int n, float scale = 1.0, float offset = 0) {
|
||||
std::vector<float> linear(n);
|
||||
float step = (n == 1) ? 0 : ((end - start) / (n - 1));
|
||||
for (auto i = 0; i < linear.size(); ++i) {
|
||||
linear[i] = (start + i * step) * scale + offset;
|
||||
}
|
||||
return linear;
|
||||
}
|
||||
|
||||
Space AutoAugmentOp::GetSpace(int32_t num_bins, const std::vector<dsize_t> &image_size) {
|
||||
Space space = {{"ShearX", {Linspace(0.0, 0.3, num_bins), true}},
|
||||
{"ShearY", {Linspace(0.0, 0.3, num_bins), true}},
|
||||
{"TranslateX", {Linspace(0.0, 150.0 / 331 * image_size[1], num_bins), true}},
|
||||
{"TranslateY", {Linspace(0.0, 150.0 / 331 * image_size[0], num_bins), true}},
|
||||
{"Rotate", {Linspace(0.0, 30, num_bins), true}},
|
||||
{"Brightness", {Linspace(0.0, 0.9, num_bins), true}},
|
||||
{"Color", {Linspace(0.0, 0.9, num_bins), true}},
|
||||
{"Contrast", {Linspace(0.0, 0.9, num_bins), true}},
|
||||
{"Sharpness", {Linspace(0.0, 0.9, num_bins), true}},
|
||||
{"Posterize", {Linspace(0.0, num_bins - 1, num_bins, -4 / (num_bins - 1), 8), false}},
|
||||
{"Solarize", {Linspace(256.0, 0.0, num_bins), false}},
|
||||
{"AutoContrast", {{0}, false}},
|
||||
{"Equalize", {{0}, false}},
|
||||
{"Invert", {{0}, false}}};
|
||||
return space;
|
||||
}
|
||||
|
||||
Status AutoAugmentOp::ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||
const std::string &op_name, float magnitude) {
|
||||
if (op_name == "ShearX") {
|
||||
float_t shear = magnitude * 180 / CV_PI;
|
||||
AffineOp affine(0.0, {0, 0}, 1.0, {shear, 0.0}, interpolation_, fill_value_);
|
||||
RETURN_IF_NOT_OK(affine.Compute(input, output));
|
||||
} else if (op_name == "ShearY") {
|
||||
float_t shear = magnitude * 180 / CV_PI;
|
||||
AffineOp affine(0.0, {0, 0}, 1.0, {0.0, shear}, interpolation_, fill_value_);
|
||||
RETURN_IF_NOT_OK(affine.Compute(input, output));
|
||||
} else if (op_name == "TranslateX") {
|
||||
float_t translate = static_cast<int>(magnitude);
|
||||
AffineOp affine(0.0, {translate, 0}, 1.0, {0.0, 0.0}, interpolation_, fill_value_);
|
||||
RETURN_IF_NOT_OK(affine.Compute(input, output));
|
||||
} else if (op_name == "TranslateY") {
|
||||
float_t translate = static_cast<int>(magnitude);
|
||||
AffineOp affine(0.0, {0, translate}, 1.0, {0.0, 0.0}, interpolation_, fill_value_);
|
||||
RETURN_IF_NOT_OK(affine.Compute(input, output));
|
||||
} else if (op_name == "Rotate") {
|
||||
const int kRIndex = 0;
|
||||
const int kBIndex = 1;
|
||||
const int kGIndex = 2;
|
||||
RETURN_IF_NOT_OK(Rotate(input, output, {}, magnitude, interpolation_, false, fill_value_[kRIndex],
|
||||
fill_value_[kBIndex], fill_value_[kGIndex]));
|
||||
} else if (op_name == "Brightness") {
|
||||
RETURN_IF_NOT_OK(AdjustBrightness(input, output, 1 + magnitude));
|
||||
} else if (op_name == "Color") {
|
||||
RETURN_IF_NOT_OK(AdjustSaturation(input, output, 1 + magnitude));
|
||||
} else if (op_name == "Contrast") {
|
||||
RETURN_IF_NOT_OK(AdjustContrast(input, output, 1 + magnitude));
|
||||
} else if (op_name == "Sharpness") {
|
||||
SharpnessOp sharpness(1 + magnitude);
|
||||
RETURN_IF_NOT_OK(sharpness.Compute(input, output));
|
||||
} else if (op_name == "Posterize") {
|
||||
PosterizeOp posterize(static_cast<int>(magnitude));
|
||||
RETURN_IF_NOT_OK(posterize.Compute(input, output));
|
||||
} else if (op_name == "Solarize") {
|
||||
SolarizeOp solarize({static_cast<uint8_t>(magnitude), 255});
|
||||
RETURN_IF_NOT_OK(solarize.Compute(input, output));
|
||||
} else if (op_name == "AutoContrast") {
|
||||
RETURN_IF_NOT_OK(AutoContrast(input, output, 0.0, {}));
|
||||
} else if (op_name == "Equalize") {
|
||||
RETURN_IF_NOT_OK(Equalize(input, output));
|
||||
} else {
|
||||
InvertOp invert;
|
||||
RETURN_IF_NOT_OK(invert.Compute(input, output));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AUTO_AUGMENT_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AUTO_AUGMENT_OP_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
#include "minddata/dataset/kernels/image/lite_image_utils.h"
|
||||
#endif
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
typedef std::vector<std::vector<std::tuple<std::string, float, int32_t>>> Transforms;
|
||||
typedef std::map<std::string, std::tuple<std::vector<float>, bool>> Space;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class AutoAugmentOp : public TensorOp {
|
||||
public:
|
||||
AutoAugmentOp(AutoAugmentPolicy policy, InterpolationMode interpolation, const std::vector<uint8_t> &fill_value);
|
||||
|
||||
~AutoAugmentOp() override = default;
|
||||
|
||||
std::string Name() const override { return kAutoAugmentOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
private:
|
||||
void GetParams(int transform_num, int *transform_id, std::vector<float> *probs, std::vector<int32_t> *signs);
|
||||
|
||||
Transforms GetTransforms(AutoAugmentPolicy policy);
|
||||
|
||||
Space GetSpace(int32_t num_bins, const std::vector<dsize_t> &image_size);
|
||||
|
||||
Status ApplyAugment(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::string &op_name,
|
||||
float magnitude);
|
||||
|
||||
AutoAugmentPolicy policy_;
|
||||
InterpolationMode interpolation_;
|
||||
std::vector<uint8_t> fill_value_;
|
||||
std::mt19937 rnd_;
|
||||
Transforms transforms_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AUTO_AUGMENT_OP_H_
|
|
@ -4,6 +4,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
|
|||
set(DATASET_KERNELS_IR_VISION_SRC_FILES
|
||||
adjust_gamma_ir.cc
|
||||
affine_ir.cc
|
||||
auto_augment_ir.cc
|
||||
auto_contrast_ir.cc
|
||||
bounding_box_augment_ir.cc
|
||||
center_crop_ir.cc
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/auto_augment_op.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
#ifndef ENABLE_ANDROID
|
||||
// AutoAugmentOperation
|
||||
AutoAugmentOperation::AutoAugmentOperation(AutoAugmentPolicy policy, InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value)
|
||||
: policy_(policy), interpolation_(interpolation), fill_value_(fill_value) {}
|
||||
|
||||
AutoAugmentOperation::~AutoAugmentOperation() = default;
|
||||
|
||||
std::string AutoAugmentOperation::Name() const { return kAutoAugmentOperation; }
|
||||
|
||||
Status AutoAugmentOperation::ValidateParams() {
|
||||
if (policy_ != AutoAugmentPolicy::kImageNet && policy_ != AutoAugmentPolicy::kCifar10 &&
|
||||
policy_ != AutoAugmentPolicy::kSVHN) {
|
||||
std::string err_msg = "AutoAugment: Invalid AutoAugmentPolicy, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (interpolation_ != InterpolationMode::kLinear && interpolation_ != InterpolationMode::kNearestNeighbour &&
|
||||
interpolation_ != InterpolationMode::kCubic && interpolation_ != InterpolationMode::kArea) {
|
||||
std::string err_msg = "AutoAugment: Invalid InterpolationMode, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateVectorFillvalue("AutoAugment", fill_value_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> AutoAugmentOperation::Build() {
|
||||
std::shared_ptr<AutoAugmentOp> tensor_op = std::make_shared<AutoAugmentOp>(policy_, interpolation_, fill_value_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status AutoAugmentOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["policy"] = policy_;
|
||||
args["interpolation"] = interpolation_;
|
||||
args["fill_value"] = fill_value_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AutoAugmentOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("policy") != op_params.end(), "Failed to find degrees");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("interpolation") != op_params.end(), "Failed to find translate");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("fill_value") != op_params.end(), "Failed to find scale");
|
||||
AutoAugmentPolicy policy = op_params["policy"];
|
||||
InterpolationMode interpolation = op_params["interpolation"];
|
||||
std::vector<uint8_t> fill_value = op_params["fill_value"];
|
||||
*operation = std::make_shared<vision::AutoAugmentOperation>(policy, interpolation, fill_value);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_AUTO_AUGMENT_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_AUTO_AUGMENT_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
#include "minddata/dataset/kernels/image/auto_augment_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace vision {
|
||||
constexpr char kAutoAugmentOperation[] = "AutoAugment";
|
||||
|
||||
class AutoAugmentOperation : public TensorOperation {
|
||||
public:
|
||||
AutoAugmentOperation(AutoAugmentPolicy policy, InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value);
|
||||
|
||||
~AutoAugmentOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
AutoAugmentPolicy policy_;
|
||||
InterpolationMode interpolation_;
|
||||
std::vector<uint8_t> fill_value_;
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_AUTO_AUGMENT_IR_H_
|
|
@ -55,6 +55,7 @@ constexpr char kTensorOp[] = "TensorOp";
|
|||
// image
|
||||
constexpr char kAdjustGammaOp[] = "AdjustGammaOp";
|
||||
constexpr char kAffineOp[] = "AffineOp";
|
||||
constexpr char kAutoAugmentOp[] = "AutoAugmentOp";
|
||||
constexpr char kAutoContrastOp[] = "AutoContrastOp";
|
||||
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
|
||||
constexpr char kDecodeOp[] = "DecodeOp";
|
||||
|
|
|
@ -27,4 +27,4 @@ Common imported modules in corresponding API examples are as follows:
|
|||
"""
|
||||
from . import c_transforms
|
||||
from . import py_transforms
|
||||
from .utils import Inter, Border, ImageBatchFormat, SliceMode
|
||||
from .utils import Inter, Border, ImageBatchFormat, SliceMode, AutoAugmentPolicy
|
||||
|
|
|
@ -47,12 +47,12 @@ import numpy as np
|
|||
from PIL import Image
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
|
||||
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
|
||||
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, \
|
||||
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
|
||||
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, check_alpha, \
|
||||
check_uniform_augment_cpp, check_convert_color, check_random_resize_crop, check_random_auto_contrast, \
|
||||
check_random_adjust_sharpness, \
|
||||
check_random_adjust_sharpness, check_auto_augment, \
|
||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
|
||||
check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches, check_adjust_gamma
|
||||
|
@ -76,6 +76,10 @@ class ImageTensorOperation(TensorOperation):
|
|||
"ImageTensorOperation has to implement parse() method.")
|
||||
|
||||
|
||||
DE_C_AUTO_AUGMENT_POLICY = {AutoAugmentPolicy.IMAGENET: cde.AutoAugmentPolicy.DE_AUTO_AUGMENT_POLICY_IMAGENET,
|
||||
AutoAugmentPolicy.CIFAR10: cde.AutoAugmentPolicy.DE_AUTO_AUGMENT_POLICY_CIFAR10,
|
||||
AutoAugmentPolicy.SVHN: cde.AutoAugmentPolicy.DE_AUTO_AUGMENT_POLICY_SVHN}
|
||||
|
||||
DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
|
||||
Border.EDGE: cde.BorderType.DE_BORDER_EDGE,
|
||||
Border.REFLECT: cde.BorderType.DE_BORDER_REFLECT,
|
||||
|
@ -161,6 +165,62 @@ class AdjustGamma(ImageTensorOperation):
|
|||
return cde.AdjustGammaOperation(self.gamma, self.gain)
|
||||
|
||||
|
||||
class AutoAugment(ImageTensorOperation):
|
||||
"""
|
||||
Apply AutoAugment data augmentation method based on
|
||||
`AutoAugment: Learning Augmentation Strategies from Data <https://arxiv.org/pdf/1805.09501.pdf>`_.
|
||||
This operation works only with 3-channel RGB images.
|
||||
|
||||
Args:
|
||||
policy (AutoAugmentPolicy, optional): AutoAugment policies learned on different datasets
|
||||
(default=AutoAugmentPolicy.IMAGENET).
|
||||
It can be any of [AutoAugmentPolicy.IMAGENET, AutoAugmentPolicy.CIFAR10, AutoAugmentPolicy.SVHN].
|
||||
Randomly apply 2 operations from a candidate set. See auto augmentation details in AutoAugmentPolicy.
|
||||
|
||||
- AutoAugmentPolicy.IMAGENET, means to apply AutoAugment learned on ImageNet dataset.
|
||||
|
||||
- AutoAugmentPolicy.CIFAR10, means to apply AutoAugment learned on Cifar10 dataset.
|
||||
|
||||
- AutoAugmentPolicy.SVHN, means to apply AutoAugment learned on SVHN dataset.
|
||||
|
||||
interpolation (Inter, optional): Image interpolation mode for Resize operator (default=Inter.NEAREST).
|
||||
It can be any of [Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC, Inter.AREA].
|
||||
|
||||
- Inter.NEAREST: means interpolation method is nearest-neighbor interpolation.
|
||||
|
||||
- Inter.BILINEAR: means interpolation method is bilinear interpolation.
|
||||
|
||||
- Inter.BICUBIC: means the interpolation method is bicubic interpolation.
|
||||
|
||||
- Inter.AREA: means the interpolation method is area interpolation.
|
||||
|
||||
fill_value (Union[int, tuple], optional): Pixel fill value for the area outside the transformed image.
|
||||
It can be an int or a 3-tuple. If it is a 3-tuple, it is used to fill R, G, B channels respectively.
|
||||
If it is an integer, it is used for all RGB channels. The fill_value values must be in range [0, 255]
|
||||
(default=0).
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.vision import AutoAugmentPolicy, Inter
|
||||
>>> transforms_list = [c_vision.Decode(), c_vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET,
|
||||
... interpolation=Inter.NEAREST,
|
||||
... fill_value=0)]
|
||||
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
|
||||
... input_columns=["image"])
|
||||
"""
|
||||
|
||||
@check_auto_augment
|
||||
def __init__(self, policy=AutoAugmentPolicy.IMAGENET, interpolation=Inter.NEAREST, fill_value=0):
|
||||
self.policy = policy
|
||||
self.interpolation = interpolation
|
||||
if isinstance(fill_value, int):
|
||||
fill_value = tuple([fill_value] * 3)
|
||||
self.fill_value = fill_value
|
||||
|
||||
def parse(self):
|
||||
return cde.AutoAugmentOperation(DE_C_AUTO_AUGMENT_POLICY[self.policy], DE_C_INTER_MODE[self.interpolation],
|
||||
self.fill_value)
|
||||
|
||||
|
||||
class AutoContrast(ImageTensorOperation):
|
||||
"""
|
||||
Apply automatic contrast on input image. This operator calculates histogram of image, reassign cutoff percent
|
||||
|
|
|
@ -109,3 +109,67 @@ class SliceMode(IntEnum):
|
|||
"""
|
||||
PAD = 0
|
||||
DROP = 1
|
||||
|
||||
|
||||
class AutoAugmentPolicy(IntEnum):
|
||||
"""
|
||||
AutoAugment policy for different datasets.
|
||||
|
||||
Possible enumeration values are: AutoAugmentPolicy.IMAGENET, AutoAugmentPolicy.CIFAR10,
|
||||
AutoAugmentPolicy.SVHN.
|
||||
|
||||
Each policy contains 25 pairs of augmentation operations. When using AutoAugment, each image is randomly
|
||||
transformed with one of these operation pairs. Each pair has 2 different operations. The following shows
|
||||
all of these augmentation operations, including operation names with their probabilities and random params.
|
||||
|
||||
- AutoAugmentPolicy.IMAGENET: dataset auto augment policy for ImageNet.
|
||||
Augmentation operations pair:
|
||||
[(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
|
||||
(("Equalize", 0.8, None), ("Equalize", 0.6, None)), (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
|
||||
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)), (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
|
||||
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)), (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
|
||||
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
|
||||
(("Rotate", 0.8, 8), ("Color", 0.4, 0)), (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
|
||||
(("Equalize", 0.0, None), ("Equalize", 0.8, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)),
|
||||
(("Color", 0.6, 4), ("Contrast", 1.0, 8)), (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
|
||||
(("Color", 0.8, 8), ("Solarize", 0.8, 7)), (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
|
||||
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)), (("Color", 0.4, 0), ("Equalize", 0.6, None)),
|
||||
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)), (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
|
||||
(("Invert", 0.6, None), ("Equalize", 1.0, None)), (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
|
||||
(("Equalize", 0.8, None), ("Equalize", 0.6, None))]
|
||||
|
||||
- AutoAugmentPolicy.CIFAR10: dataset auto augment policy for Cifar10.
|
||||
Augmentation operations pair:
|
||||
[(("Invert", 0.1, None), ("Contrast", 0.2, 6)), (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
|
||||
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
|
||||
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
|
||||
(("Color", 0.4, 3), ("Brightness", 0.6, 7)), (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
|
||||
(("Equalize", 0.6, None), ("Equalize", 0.5, None)), (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
|
||||
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)), (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
|
||||
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
|
||||
(("Solarize", 0.5, 2), ("Invert", 0.0, None)), (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
|
||||
(("Equalize", 0.2, None), ("Equalize", 0.6, None)), (("Color", 0.9, 9), ("Equalize", 0.6, None)),
|
||||
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
|
||||
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
|
||||
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), (("Equalize", 0.8, None), ("Invert", 0.1, None)),
|
||||
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None))]
|
||||
|
||||
- AutoAugmentPolicy.SVHN: dataset auto augment policy for SVHN.
|
||||
Augmentation operations pair:
|
||||
[(("ShearX", 0.9, 4), ("Invert", 0.2, None)), (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
|
||||
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)),
|
||||
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)), (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
|
||||
(("ShearY", 0.9, 8), ("Invert", 0.4, None)), (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
|
||||
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
|
||||
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
|
||||
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)),
|
||||
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
|
||||
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
|
||||
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
|
||||
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
|
||||
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
|
||||
(("ShearX", 0.7, 2), ("Invert", 0.1, None))]
|
||||
"""
|
||||
IMAGENET = 0
|
||||
CIFAR10 = 1
|
||||
SVHN = 2
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.dataset.core.validator_helpers import check_value, check_uint8, F
|
|||
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
|
||||
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
|
||||
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32
|
||||
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
|
||||
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
|
||||
|
||||
|
||||
def check_crop_size(size):
|
||||
|
@ -1031,3 +1031,18 @@ def check_convert_color(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_auto_augment(method):
|
||||
"""Wrapper method to check the parameters of AutoAugment."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[policy, interpolation, fill_value], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(policy, (AutoAugmentPolicy,), "policy")
|
||||
type_check(interpolation, (Inter,), "interpolation")
|
||||
check_fill_value(fill_value)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -1270,3 +1270,57 @@ TEST_F(MindDataTestPipeline, TestConvertColorFail) {
|
|||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: AutoAugment
|
||||
/// Description: test AutoAugment pipeline
|
||||
/// Expectation: create an ImageFolder dataset then do auto augmentation on it with the policy
|
||||
TEST_F(MindDataTestPipeline, TestAutoAugment) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAutoAugment.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto auto_augment_op = vision::AutoAugment(AutoAugmentPolicy::kImageNet, InterpolationMode::kLinear, {0, 0, 0});
|
||||
|
||||
ds = ds->Map({auto_augment_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: AutoAugment
|
||||
/// Description: test AutoAugment with invalid fill_value
|
||||
/// Expectation: pipeline iteration failed with wrong argument fill_value
|
||||
TEST_F(MindDataTestPipeline, TestAutoAugmentInvalidFillValue) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAutoAugmentInvalidFillValue.";
|
||||
|
||||
std::string MindDataPath = "data/dataset";
|
||||
std::string folder_path = MindDataPath + "/testImageNetData/train/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 2));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto auto_augment_op = vision::AutoAugment(AutoAugmentPolicy::kImageNet,
|
||||
InterpolationMode::kNearestNeighbour, {20, 20});
|
||||
|
||||
ds = ds->Map({auto_augment_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -1564,3 +1564,20 @@ TEST_F(MindDataTestExecute, TestSlidingWindowCmnWrongArgs) {
|
|||
Status status_2 = Transform_2(input_ms, &input_ms);
|
||||
EXPECT_FALSE(status_2.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: AutoAugment
|
||||
/// Description: test AutoAugment eager
|
||||
/// Expectation: load one image data and process auto augmentation with given policy on it.
|
||||
TEST_F(MindDataTestExecute, TestAutoAugmentEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestAutoAugmentEager.";
|
||||
// Read images
|
||||
auto image = ReadFileToTensor("data/dataset/apple.jpg");
|
||||
|
||||
// Transform params
|
||||
auto decode = vision::Decode();
|
||||
auto auto_augment_op = vision::AutoAugment(AutoAugmentPolicy::kImageNet, InterpolationMode::kLinear, {0, 0, 0});
|
||||
|
||||
auto transform = Execute({decode, auto_augment_op});
|
||||
Status rc = transform(image, &image);
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,197 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing AutoAugment in DE
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.vision.c_transforms import Decode, AutoAugment, Resize
|
||||
from mindspore.dataset.vision.utils import AutoAugmentPolicy, Inter
|
||||
from mindspore import log as logger
|
||||
from util import visualize_image, visualize_list, diff_mse
|
||||
|
||||
image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg"
|
||||
data_dir = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def test_auto_augment_pipeline(plot=False):
|
||||
"""
|
||||
Feature: AutoAugment
|
||||
Description: test AutoAugment pipeline
|
||||
Expectation: pass without error
|
||||
"""
|
||||
logger.info("Test AutoAugment pipeline")
|
||||
|
||||
# Original Images
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
transforms_original = [Decode(), Resize(size=[224, 224])]
|
||||
ds_original = data_set.map(operations=transforms_original, input_columns="image")
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image.asnumpy()
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
image.asnumpy(),
|
||||
axis=0)
|
||||
|
||||
# Auto Augmented Images with ImageNet policy
|
||||
data_set1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
auto_augment_op = AutoAugment(AutoAugmentPolicy.IMAGENET, Inter.BICUBIC, 20)
|
||||
transforms = [Decode(), Resize(size=[224, 224]), auto_augment_op]
|
||||
ds_auto_augment = data_set1.map(operations=transforms, input_columns="image")
|
||||
ds_auto_augment = ds_auto_augment.batch(512)
|
||||
for idx, (image, _) in enumerate(ds_auto_augment):
|
||||
if idx == 0:
|
||||
images_auto_augment = image.asnumpy()
|
||||
else:
|
||||
images_auto_augment = np.append(images_auto_augment,
|
||||
image.asnumpy(),
|
||||
axis=0)
|
||||
assert images_original.shape[0] == images_auto_augment.shape[0]
|
||||
if plot:
|
||||
visualize_list(images_original, images_auto_augment)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_auto_augment[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
# Auto Augmented Images with Cifar10 policy
|
||||
data_set2 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
auto_augment_op = AutoAugment(AutoAugmentPolicy.CIFAR10, Inter.BILINEAR, 20)
|
||||
transforms = [Decode(), Resize(size=[224, 224]), auto_augment_op]
|
||||
ds_auto_augment = data_set2.map(operations=transforms, input_columns="image")
|
||||
ds_auto_augment = ds_auto_augment.batch(512)
|
||||
for idx, (image, _) in enumerate(ds_auto_augment):
|
||||
if idx == 0:
|
||||
images_auto_augment = image.asnumpy()
|
||||
else:
|
||||
images_auto_augment = np.append(images_auto_augment,
|
||||
image.asnumpy(),
|
||||
axis=0)
|
||||
assert images_original.shape[0] == images_auto_augment.shape[0]
|
||||
if plot:
|
||||
visualize_list(images_original, images_auto_augment)
|
||||
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_auto_augment[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
# Auto Augmented Images with SVHN policy
|
||||
data_set3 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
auto_augment_op = AutoAugment(AutoAugmentPolicy.SVHN, Inter.NEAREST, 20)
|
||||
transforms = [Decode(), Resize(size=[224, 224]), auto_augment_op]
|
||||
ds_auto_augment = data_set3.map(operations=transforms, input_columns="image")
|
||||
ds_auto_augment = ds_auto_augment.batch(512)
|
||||
for idx, (image, _) in enumerate(ds_auto_augment):
|
||||
if idx == 0:
|
||||
images_auto_augment = image.asnumpy()
|
||||
else:
|
||||
images_auto_augment = np.append(images_auto_augment,
|
||||
image.asnumpy(),
|
||||
axis=0)
|
||||
assert images_original.shape[0] == images_auto_augment.shape[0]
|
||||
if plot:
|
||||
visualize_list(images_original, images_auto_augment)
|
||||
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_auto_augment[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_auto_augment_eager(plot=False):
|
||||
"""
|
||||
Feature: AutoAugment
|
||||
Description: test AutoAugment eager
|
||||
Expectation: pass without error
|
||||
"""
|
||||
img = np.fromfile(image_file, dtype=np.uint8)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
|
||||
|
||||
img = Decode()(img)
|
||||
img_auto_augmented = AutoAugment()(img)
|
||||
if plot:
|
||||
visualize_image(img, img_auto_augmented)
|
||||
logger.info("Image.type: {}, Image.shape: {}".format(type(img_auto_augmented), img_auto_augmented.shape))
|
||||
mse = diff_mse(img_auto_augmented, img)
|
||||
logger.info("MSE= {}".format(str(mse)))
|
||||
|
||||
|
||||
def test_auto_augment_invalid_policy():
|
||||
"""
|
||||
Feature: AutoAugment
|
||||
Description: test AutoAugment with invalid policy
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
logger.info("test_auto_augment_invalid_policy")
|
||||
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
try:
|
||||
auto_augment_op = AutoAugment(policy="invalid")
|
||||
dataset.map(operations=auto_augment_op, input_columns=['image'])
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Argument policy with value invalid is not of type [<enum 'AutoAugmentPolicy'>]" in str(e)
|
||||
|
||||
|
||||
def test_auto_augment_invalid_interpolation():
|
||||
"""
|
||||
Feature: AutoAugment
|
||||
Description: test AutoAugment with invalid interpolation
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
logger.info("test_auto_augment_invalid_interpolation")
|
||||
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
try:
|
||||
auto_augment_op = AutoAugment(interpolation="invalid")
|
||||
dataset.map(operations=auto_augment_op, input_columns=['image'])
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Argument interpolation with value invalid is not of type [<enum 'Inter'>]" in str(e)
|
||||
|
||||
|
||||
def test_auto_augment_invalid_fill_value():
|
||||
"""
|
||||
Feature: AutoAugment
|
||||
Description: test AutoAugment with invalid fill_value
|
||||
Expectation: throw TypeError or ValueError
|
||||
"""
|
||||
logger.info("test_auto_augment_invalid_fill_value")
|
||||
dataset = ds.ImageFolderDataset(data_dir, 1, shuffle=False, decode=True)
|
||||
try:
|
||||
auto_augment_op = AutoAugment(fill_value=(10, 10))
|
||||
dataset.map(operations=auto_augment_op, input_columns=['image'])
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "fill_value should be a single integer or a 3-tuple." in str(e)
|
||||
try:
|
||||
auto_augment_op = AutoAugment(fill_value=300)
|
||||
dataset.map(operations=auto_augment_op, input_columns=['image'])
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "is not within the required interval of [0, 255]." in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_auto_augment_pipeline(plot=True)
|
||||
test_auto_augment_eager(plot=True)
|
||||
test_auto_augment_invalid_policy()
|
||||
test_auto_augment_invalid_interpolation()
|
||||
test_auto_augment_invalid_fill_value()
|
Loading…
Reference in New Issue