forked from mindspore-Ecosystem/mindspore
Added CutMix
This commit is contained in:
parent
04056cf8bc
commit
3ecc53fb4e
|
@ -110,5 +110,12 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) {
|
|||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ImageBatchFormat, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<ImageBatchFormat>(*m, "ImageBatchFormat", py::arithmetic())
|
||||
.value("DE_IMAGE_BATCH_FORMAT_NHWC", ImageBatchFormat::kNHWC)
|
||||
.value("DE_IMAGE_BATCH_FORMAT_NCHW", ImageBatchFormat::kNCHW)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
|
||||
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
|
||||
#include "minddata/dataset/kernels/image/cut_out_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/equalize_op.h"
|
||||
|
@ -104,6 +105,13 @@ PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) {
|
|||
.def(py::init<float>(), py::arg("alpha"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(CutMixBatchOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<CutMixBatchOp, TensorOp, std::shared_ptr<CutMixBatchOp>>(
|
||||
*m, "CutMixBatchOp", "Tensor operation to cutmix a batch of images")
|
||||
.def(py::init<ImageBatchFormat, float, float>(), py::arg("image_batch_format"), py::arg("alpha"),
|
||||
py::arg("prob"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
|
||||
*m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
|
||||
#include "minddata/dataset/kernels/image/cut_out_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
|
@ -69,6 +70,16 @@ std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vecto
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create CutMixBatchOperation.
|
||||
std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha, float prob) {
|
||||
auto op = std::make_shared<CutMixBatchOperation>(image_batch_format, alpha, prob);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create CutOutOp.
|
||||
std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
|
||||
auto op = std::make_shared<CutOutOperation>(length, num_patches);
|
||||
|
@ -339,6 +350,27 @@ std::shared_ptr<TensorOp> CropOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// CutMixBatchOperation
|
||||
CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob)
|
||||
: image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
|
||||
|
||||
bool CutMixBatchOperation::ValidateParams() {
|
||||
if (alpha_ < 0) {
|
||||
MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
|
||||
return false;
|
||||
}
|
||||
if (prob_ < 0 || prob_ > 1) {
|
||||
MS_LOG(ERROR) << "CutMixBatch: Probability has to be between 0 and 1.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> CutMixBatchOperation::Build() {
|
||||
std::shared_ptr<CutMixBatchOp> tensor_op = std::make_shared<CutMixBatchOp>(image_batch_format_, alpha_, prob_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// CutOutOperation
|
||||
CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
|
||||
|
||||
|
|
|
@ -41,6 +41,12 @@ enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
|
|||
// Possible values for Border types
|
||||
enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 };
|
||||
|
||||
// Possible values for Image format types in a batch
|
||||
enum class ImageBatchFormat { kNHWC = 0, kNCHW = 1 };
|
||||
|
||||
// Possible values for Image format types
|
||||
enum class ImageFormat { HWC = 0, CHW = 1, HW = 2 };
|
||||
|
||||
// Possible interpolation modes
|
||||
enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 };
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ namespace vision {
|
|||
// Transform Op classes (in alphabetical order)
|
||||
class CenterCropOperation;
|
||||
class CropOperation;
|
||||
class CutMixBatchOperation;
|
||||
class CutOutOperation;
|
||||
class DecodeOperation;
|
||||
class HwcToChwOperation;
|
||||
|
@ -85,6 +86,16 @@ std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size);
|
|||
/// \return Shared pointer to the current TensorOp
|
||||
std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size);
|
||||
|
||||
/// \brief Function to apply CutMix on a batch of images
|
||||
/// \notes Masks a random section of each image with the corresponding part of another randomly selected image in
|
||||
/// that batch
|
||||
/// \param[in] image_batch_format The format of the batch
|
||||
/// \param[in] alpha The hyperparameter of beta distribution (default = 1.0)
|
||||
/// \param[in] prob The probability by which CutMix is applied to each image (default = 1.0)
|
||||
/// \return Shared pointer to the current TensorOp
|
||||
std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha = 1.0,
|
||||
float prob = 1.0);
|
||||
|
||||
/// \brief Function to create a CutOut TensorOp
|
||||
/// \notes Randomly cut (mask) out a given number of square patches from the input image
|
||||
/// \param[in] length Integer representing the side length of each square patch
|
||||
|
@ -296,6 +307,22 @@ class CropOperation : public TensorOperation {
|
|||
std::vector<int32_t> size_;
|
||||
};
|
||||
|
||||
class CutMixBatchOperation : public TensorOperation {
|
||||
public:
|
||||
explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha = 1.0, float prob = 1.0);
|
||||
|
||||
~CutMixBatchOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float prob_;
|
||||
ImageBatchFormat image_batch_format_;
|
||||
};
|
||||
|
||||
class CutOutOperation : public TensorOperation {
|
||||
public:
|
||||
explicit CutOutOperation(int32_t length, int32_t num_patches = 1);
|
||||
|
@ -309,6 +336,7 @@ class CutOutOperation : public TensorOperation {
|
|||
private:
|
||||
int32_t length_;
|
||||
int32_t num_patches_;
|
||||
ImageBatchFormat image_batch_format_;
|
||||
};
|
||||
|
||||
class DecodeOperation : public TensorOperation {
|
||||
|
|
|
@ -655,7 +655,7 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
|
|||
TensorShape remaining({-1});
|
||||
std::vector<int64_t> index(tensor_shape.size(), 0);
|
||||
if (tensor_shape.size() <= 1) {
|
||||
RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack");
|
||||
RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack.");
|
||||
}
|
||||
TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));
|
||||
|
||||
|
@ -664,15 +664,48 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
|
|||
std::shared_ptr<Tensor> out;
|
||||
|
||||
RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(input->CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
|
||||
std::shared_ptr<CVTensor> cv_out = CVTensor::AsCVTensor(std::move(out));
|
||||
if (!cv_out->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor.");
|
||||
}
|
||||
output->push_back(cv_out);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output) {
|
||||
std::vector<int64_t> tensor_shape = input->shape().AsVector();
|
||||
TensorShape remaining({-1});
|
||||
std::vector<int64_t> index(tensor_shape.size(), 0);
|
||||
if (tensor_shape.size() <= 1) {
|
||||
RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack.");
|
||||
}
|
||||
TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));
|
||||
|
||||
for (; index[0] < tensor_shape[0]; index[0]++) {
|
||||
uchar *start_addr_of_index = nullptr;
|
||||
std::shared_ptr<Tensor> out;
|
||||
|
||||
RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
|
||||
output->push_back(out);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output) {
|
||||
if (input.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("TensorVectorToBatchTensor: Received an empty vector.");
|
||||
}
|
||||
std::vector<int64_t> tensor_shape = input.front()->shape().AsVector();
|
||||
tensor_shape.insert(tensor_shape.begin(), input.size());
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(tensor_shape), input.at(0)->type(), output));
|
||||
for (int i = 0; i < input.size(); i++) {
|
||||
RETURN_IF_NOT_OK((*output)->InsertTensor({i}, input[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -158,11 +158,24 @@ Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
|
|||
std::shared_ptr<Tensor> append);
|
||||
|
||||
/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors
|
||||
/// @param input[in] input tensor
|
||||
/// @param output[out] output tensor
|
||||
/// @return Status ok/error
|
||||
/// \param input[in] input tensor
|
||||
/// \param output[out] output vector of CVTensors
|
||||
/// \return Status ok/error
|
||||
Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
|
||||
std::vector<std::shared_ptr<CVTensor>> *output);
|
||||
|
||||
/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional Tensors
|
||||
/// \param input[in] input tensor
|
||||
/// \param output[out] output vector of tensors
|
||||
/// \return Status ok/error
|
||||
Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vector<std::shared_ptr<Tensor>> *output);
|
||||
|
||||
/// Convert a vector of (n-1)-dimensional Tensors to an n-dimensional Tensor
|
||||
/// \param input[in] input vector of tensors
|
||||
/// \param output[out] output tensor
|
||||
/// \return Status ok/error
|
||||
Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ add_library(kernels-image OBJECT
|
|||
center_crop_op.cc
|
||||
crop_op.cc
|
||||
cut_out_op.cc
|
||||
cutmix_batch_op.cc
|
||||
decode_op.cc
|
||||
equalize_op.cc
|
||||
hwc_to_chw_op.cc
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
CutMixBatchOp::CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, float prob)
|
||||
: image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y, int *crop_width, int *crop_height) {
|
||||
float cut_ratio = 1 - lam;
|
||||
int cut_w = static_cast<int>(width * cut_ratio);
|
||||
int cut_h = static_cast<int>(height * cut_ratio);
|
||||
std::uniform_int_distribution<int> width_uniform_distribution(0, width);
|
||||
std::uniform_int_distribution<int> height_uniform_distribution(0, height);
|
||||
int cx = width_uniform_distribution(rnd_);
|
||||
int x2, y2;
|
||||
int cy = height_uniform_distribution(rnd_);
|
||||
*x = std::clamp(cx - cut_w / 2, 0, width - 1); // horizontal coordinate of left side of crop box
|
||||
*y = std::clamp(cy - cut_h / 2, 0, height - 1); // vertical coordinate of the top side of crop box
|
||||
x2 = std::clamp(cx + cut_w / 2, 0, width - 1); // horizontal coordinate of right side of crop box
|
||||
y2 = std::clamp(cy + cut_h / 2, 0, height - 1); // vertical coordinate of the bottom side of crop box
|
||||
*crop_width = std::clamp(x2 - *x, 1, width - 1);
|
||||
*crop_height = std::clamp(y2 - *y, 1, height - 1);
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||
if (input.size() < 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Tensor>> images;
|
||||
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
|
||||
std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();
|
||||
|
||||
// Check inputs
|
||||
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
RETURN_STATUS_UNEXPECTED("You must batch before calling CutMixBatch.");
|
||||
}
|
||||
if (label_shape.size() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch");
|
||||
}
|
||||
if ((image_shape[1] != 1 && image_shape[1] != 3) && image_batch_format_ == ImageBatchFormat::kNCHW) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format.");
|
||||
}
|
||||
if ((image_shape[3] != 1 && image_shape[3] != 3) && image_batch_format_ == ImageBatchFormat::kNHWC) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format.");
|
||||
}
|
||||
|
||||
// Move images into a vector of Tensors
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(input.at(0), &images));
|
||||
|
||||
// Calculate random labels
|
||||
std::vector<int64_t> rand_indx;
|
||||
for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i);
|
||||
std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_);
|
||||
|
||||
std::gamma_distribution<float> gamma_distribution(alpha_, 1);
|
||||
std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0);
|
||||
|
||||
// Tensor holding the output labels
|
||||
std::shared_ptr<Tensor> out_labels;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(label_shape), DataType(DataType::DE_FLOAT32), &out_labels));
|
||||
|
||||
// Compute labels and images
|
||||
for (int i = 0; i < image_shape[0]; i++) {
|
||||
// Calculating lambda
|
||||
// If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1)
|
||||
// then x = x1 / (x1+x2) is a random variable from Beta(a1, a2)
|
||||
float x1 = gamma_distribution(rnd_);
|
||||
float x2 = gamma_distribution(rnd_);
|
||||
float lam = x1 / (x1 + x2);
|
||||
double random_number = uniform_distribution(rnd_);
|
||||
if (random_number < prob_) {
|
||||
int x, y, crop_width, crop_height;
|
||||
float label_lam; // lambda used for labels
|
||||
|
||||
// Get a random image
|
||||
TensorShape remaining({-1});
|
||||
uchar *start_addr_of_index = nullptr;
|
||||
std::shared_ptr<Tensor> rand_image;
|
||||
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
|
||||
input.at(0)->type(), start_addr_of_index, &rand_image));
|
||||
|
||||
// Compute image
|
||||
if (image_batch_format_ == ImageBatchFormat::kNHWC) {
|
||||
// NHWC Format
|
||||
GetCropBox(static_cast<int32_t>(image_shape[1]), static_cast<int32_t>(image_shape[2]), lam, &x, &y, &crop_width,
|
||||
&crop_height);
|
||||
std::shared_ptr<Tensor> cropped;
|
||||
RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height));
|
||||
RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC));
|
||||
label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2]));
|
||||
|
||||
} else {
|
||||
// NCHW Format
|
||||
GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width,
|
||||
&crop_height);
|
||||
std::vector<std::shared_ptr<Tensor>> channels; // A vector holding channels of the CHW image
|
||||
std::vector<std::shared_ptr<Tensor>> cropped_channels; // A vector holding the channels of the cropped CHW
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels));
|
||||
for (auto channel : channels) {
|
||||
// Call crop for each single channel
|
||||
std::shared_ptr<Tensor> cropped_channel;
|
||||
RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height));
|
||||
cropped_channels.push_back(cropped_channel);
|
||||
}
|
||||
std::shared_ptr<Tensor> cropped;
|
||||
// Merge channels to a single tensor
|
||||
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped));
|
||||
|
||||
RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::CHW));
|
||||
label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3]));
|
||||
}
|
||||
|
||||
// Compute labels
|
||||
for (int j = 0; j < label_shape[1]; j++) {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> out_images;
|
||||
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(images, &out_images));
|
||||
|
||||
// Move the output into a TensorRow
|
||||
output->push_back(out_images);
|
||||
output->push_back(out_labels);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CutMixBatchOp::Print(std::ostream &out) const {
|
||||
out << "CutMixBatchOp: "
|
||||
<< "image_batch_format: " << image_batch_format_ << "alpha: " << alpha_ << ", probability: " << prob_ << "\n";
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class CutMixBatchOp : public TensorOp {
|
||||
public:
|
||||
explicit CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, float prob);
|
||||
|
||||
~CutMixBatchOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
void GetCropBox(int width, int height, float lam, int *x, int *y, int *crop_width, int *crop_height);
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
std::string Name() const override { return kCutMixBatchOp; }
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float prob_;
|
||||
ImageBatchFormat image_batch_format_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_
|
|
@ -402,6 +402,62 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output)
|
|||
}
|
||||
}
|
||||
|
||||
Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y,
|
||||
int crop_width, int crop_height, ImageFormat image_format) {
|
||||
if (image_format == ImageFormat::HWC) {
|
||||
if ((*input)->Rank() != 3 || ((*input)->shape()[2] != 1 && (*input)->shape()[2] != 3)) {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format.");
|
||||
}
|
||||
if (sub_mat->Rank() != 3 || (sub_mat->shape()[2] != 1 && sub_mat->shape()[2] != 3)) {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format.");
|
||||
}
|
||||
int number_of_channels = (*input)->shape()[2];
|
||||
for (int i = 0; i < crop_width; i++) {
|
||||
for (int j = 0; j < crop_height; j++) {
|
||||
for (int c = 0; c < number_of_channels; c++) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i, c}));
|
||||
RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i, c}, pixel_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (image_format == ImageFormat::CHW) {
|
||||
if ((*input)->Rank() != 3 || ((*input)->shape()[0] != 1 && (*input)->shape()[0] != 3)) {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format.");
|
||||
}
|
||||
if (sub_mat->Rank() != 3 || (sub_mat->shape()[0] != 1 && sub_mat->shape()[0] != 3)) {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format.");
|
||||
}
|
||||
int number_of_channels = (*input)->shape()[0];
|
||||
for (int i = 0; i < crop_width; i++) {
|
||||
for (int j = 0; j < crop_height; j++) {
|
||||
for (int c = 0; c < number_of_channels; c++) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {c, j, i}));
|
||||
RETURN_IF_NOT_OK((*input)->SetItemAt({c, y + j, x + i}, pixel_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (image_format == ImageFormat::HW) {
|
||||
if ((*input)->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image shape doesn't match the given image_format.");
|
||||
}
|
||||
if (sub_mat->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: sub_mat shape doesn't match the given image_format.");
|
||||
}
|
||||
for (int i = 0; i < crop_width; i++) {
|
||||
for (int j = 0; j < crop_height; j++) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i}));
|
||||
RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i}, pixel_value));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("MaskWithTensor: Image format must be CHW, HWC, or HW.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
|
||||
|
|
|
@ -120,6 +120,19 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|||
/// \param output: Tensor of shape <C,H,W> or <H,W> and same input type.
|
||||
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
/// \brief Masks the given part of the input image with a another image (sub_mat)
|
||||
/// \param[in] sub_mat The image we want to mask with
|
||||
/// \param[in] input The pointer to the image we want to mask
|
||||
/// \param[in] x The horizontal coordinate of left side of crop box
|
||||
/// \param[in] y The vertical coordinate of the top side of crop box
|
||||
/// \param[in] width The width of the mask box
|
||||
/// \param[in] height The height of the mask box
|
||||
/// \param[in] image_format The format of the image (CHW or HWC)
|
||||
/// \param[out] input Masks the input image in-place and returns it
|
||||
/// @return Status ok/error
|
||||
Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, int width,
|
||||
int height, ImageFormat image_format);
|
||||
|
||||
/// \brief Swap the red and blue pixels (RGB <-> BGR)
|
||||
/// \param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor.
|
||||
/// \param output: Swapped image of same shape and type
|
||||
|
|
|
@ -37,10 +37,12 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();
|
||||
|
||||
// Check inputs
|
||||
if (label_shape.size() != 2 || image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
|
||||
}
|
||||
|
||||
if (label_shape.size() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch");
|
||||
}
|
||||
if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) {
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW");
|
||||
}
|
||||
|
|
|
@ -94,6 +94,7 @@ constexpr char kAutoContrastOp[] = "AutoContrastOp";
|
|||
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
|
||||
constexpr char kDecodeOp[] = "DecodeOp";
|
||||
constexpr char kCenterCropOp[] = "CenterCropOp";
|
||||
constexpr char kCutMixBatchOp[] = "CutMixBatchOp";
|
||||
constexpr char kCutOutOp[] = "CutOutOp";
|
||||
constexpr char kCropOp[] = "CropOp";
|
||||
constexpr char kEqualizeOp[] = "EqualizeOp";
|
||||
|
|
|
@ -43,12 +43,13 @@ Examples:
|
|||
import numbers
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .utils import Inter, Border
|
||||
from .utils import Inter, Border, ImageBatchFormat
|
||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
|
||||
check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
|
||||
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
|
||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER
|
||||
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
|
||||
check_cut_mix_batch_c
|
||||
|
||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||
|
@ -59,6 +60,8 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
|
|||
Border.REFLECT: cde.BorderType.DE_BORDER_REFLECT,
|
||||
Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC}
|
||||
|
||||
DE_C_IMAGE_BATCH_FORMAT = {ImageBatchFormat.NHWC: cde.ImageBatchFormat.DE_IMAGE_BATCH_FORMAT_NHWC,
|
||||
ImageBatchFormat.NCHW: cde.ImageBatchFormat.DE_IMAGE_BATCH_FORMAT_NCHW}
|
||||
|
||||
def parse_padding(padding):
|
||||
if isinstance(padding, numbers.Number):
|
||||
|
@ -142,6 +145,33 @@ class Decode(cde.DecodeOp):
|
|||
super().__init__(self.rgb)
|
||||
|
||||
|
||||
class CutMixBatch(cde.CutMixBatchOp):
|
||||
"""
|
||||
Apply CutMix transformation on input batch of images and labels.
|
||||
Note that you need to make labels into one-hot format and batch before calling this function.
|
||||
|
||||
Args:
|
||||
image_batch_format (Image Batch Format): The method of padding. Can be any of
|
||||
[ImageBatchFormat.NHWC, ImageBatchFormat.NCHW]
|
||||
alpha (float): hyperparameter of beta distribution (default = 1.0).
|
||||
prob (float): The probability by which CutMix is applied to each image (default = 1.0).
|
||||
|
||||
Examples:
|
||||
>>> one_hot_op = data.OneHot(num_classes=10)
|
||||
>>> data = data.map(input_columns=["label"], operations=one_hot_op)
|
||||
>>> cutmix_batch_op = vision.CutMixBatch(ImageBatchFormat.NHWC, 1.0, 0.5)
|
||||
>>> data = data.batch(5)
|
||||
>>> data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
"""
|
||||
|
||||
@check_cut_mix_batch_c
|
||||
def __init__(self, image_batch_format, alpha=1.0, prob=1.0):
|
||||
self.image_batch_format = image_batch_format.value
|
||||
self.alpha = alpha
|
||||
self.prob = prob
|
||||
super().__init__(DE_C_IMAGE_BATCH_FORMAT[image_batch_format], alpha, prob)
|
||||
|
||||
|
||||
class CutOut(cde.CutOutOp):
|
||||
"""
|
||||
Randomly cut (mask) out a given number of square patches from the input Numpy image array.
|
||||
|
|
|
@ -30,3 +30,9 @@ class Border(str, Enum):
|
|||
EDGE: str = "edge"
|
||||
REFLECT: str = "reflect"
|
||||
SYMMETRIC: str = "symmetric"
|
||||
|
||||
|
||||
# Image Batch Format
|
||||
class ImageBatchFormat(IntEnum):
|
||||
NHWC = 0
|
||||
NCHW = 1
|
||||
|
|
|
@ -19,7 +19,7 @@ from functools import wraps
|
|||
import numpy as np
|
||||
from mindspore._c_dataengine import TensorOp
|
||||
|
||||
from .utils import Inter, Border
|
||||
from .utils import Inter, Border, ImageBatchFormat
|
||||
from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
||||
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
|
||||
check_tensor_op, UINT8_MAX
|
||||
|
@ -37,6 +37,20 @@ def check_crop_size(size):
|
|||
raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
|
||||
|
||||
|
||||
def check_cut_mix_batch_c(method):
|
||||
"""Wrapper method to check the parameters of CutMixBatch."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
|
||||
check_pos_float32(alpha)
|
||||
check_value(prob, [0, 1], "prob")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_resize_size(size):
|
||||
"""Wrapper method to check the parameters of resize."""
|
||||
if isinstance(size, int):
|
||||
|
|
|
@ -20,6 +20,7 @@ SET(DE_UT_SRCS
|
|||
circular_pool_test.cc
|
||||
client_config_test.cc
|
||||
connector_test.cc
|
||||
cutmix_batch_op_test.cc
|
||||
cut_out_op_test.cc
|
||||
datatype_test.cc
|
||||
decode_op_test.cc
|
||||
|
|
|
@ -25,6 +25,177 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
|
|||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
|
||||
// Testing CutMixBatch on a batch of CHW images
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
int number_of_classes = 10;
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> hwc_to_chw = vision::HWC2CHW();
|
||||
EXPECT_NE(hwc_to_chw, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({hwc_to_chw},{"image"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(number_of_classes);
|
||||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
|
||||
EXPECT_NE(cutmix_batch_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({cutmix_batch_op}, {"image", "label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
MS_LOG(INFO) << "Label shape: " << label->shape();
|
||||
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1]
|
||||
&& 32 == image->shape()[2] && 32 == image->shape()[3], true);
|
||||
EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
|
||||
number_of_classes == label->shape()[1], true);
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
|
||||
// Calling CutMixBatch on a batch of HWC images with default values of alpha and prob
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
int number_of_classes = 10;
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(number_of_classes);
|
||||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC);
|
||||
EXPECT_NE(cutmix_batch_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({cutmix_batch_op}, {"image", "label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
MS_LOG(INFO) << "Label shape: " << label->shape();
|
||||
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1]
|
||||
&& 32 == image->shape()[2] && 3 == image->shape()[3], true);
|
||||
EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
|
||||
number_of_classes == label->shape()[1], true);
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
|
||||
// Must fail because alpha can't be negative
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
|
||||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
|
||||
// Must fail because prob can't be negative
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
|
||||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutOut) {
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestCutMixBatchOp : public UT::CVOP::CVOpCommon {
|
||||
protected:
|
||||
MindDataTestCutMixBatchOp() : CVOpCommon() {}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestCutMixBatchOp, TestSuccess1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp success1 case";
|
||||
std::shared_ptr<Tensor> batched_tensor;
|
||||
std::shared_ptr<Tensor> batched_labels;
|
||||
Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}),
|
||||
input_tensor_->type(), &batched_tensor);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
batched_tensor->InsertTensor({i}, input_tensor_);
|
||||
}
|
||||
Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels);
|
||||
std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNHWC, 1.0, 1.0);
|
||||
TensorRow in;
|
||||
in.push_back(batched_tensor);
|
||||
in.push_back(batched_labels);
|
||||
TensorRow out;
|
||||
ASSERT_TRUE(op->Compute(in, &out).IsOk());
|
||||
|
||||
EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]);
|
||||
EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]);
|
||||
EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]);
|
||||
EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]);
|
||||
|
||||
EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]);
|
||||
EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCutMixBatchOp, TestSuccess2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp success2 case";
|
||||
std::shared_ptr<Tensor> batched_tensor;
|
||||
std::shared_ptr<Tensor> batched_labels;
|
||||
std::shared_ptr<Tensor> chw_tensor;
|
||||
ASSERT_TRUE(HwcToChw(input_tensor_, &chw_tensor).IsOk());
|
||||
Tensor::CreateEmpty(TensorShape({2, chw_tensor->shape()[0], chw_tensor->shape()[1], chw_tensor->shape()[2]}),
|
||||
chw_tensor->type(), &batched_tensor);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
batched_tensor->InsertTensor({i}, chw_tensor);
|
||||
}
|
||||
Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels);
|
||||
std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNCHW, 1.0, 0.5);
|
||||
TensorRow in;
|
||||
in.push_back(batched_tensor);
|
||||
in.push_back(batched_labels);
|
||||
TensorRow out;
|
||||
ASSERT_TRUE(op->Compute(in, &out).IsOk());
|
||||
|
||||
EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]);
|
||||
EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]);
|
||||
EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]);
|
||||
EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]);
|
||||
|
||||
EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]);
|
||||
EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCutMixBatchOp, TestFail1) {
|
||||
// This is a fail case because our labels are not batched and are 1-dimensional
|
||||
MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp fail1 case";
|
||||
std::shared_ptr<Tensor> labels;
|
||||
Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({4}), &labels);
|
||||
std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNHWC, 1.0, 1.0);
|
||||
TensorRow in;
|
||||
in.push_back(input_tensor_);
|
||||
in.push_back(labels);
|
||||
TensorRow out;
|
||||
ASSERT_FALSE(op->Compute(in, &out).IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCutMixBatchOp, TestFail2) {
|
||||
// This should fail because the image_batch_format provided is not the same as the actual format of the images
|
||||
MS_LOG(INFO) << "Doing MindDataTestCutMixBatchOp fail2 case";
|
||||
std::shared_ptr<Tensor> batched_tensor;
|
||||
std::shared_ptr<Tensor> batched_labels;
|
||||
Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}),
|
||||
input_tensor_->type(), &batched_tensor);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
batched_tensor->InsertTensor({i}, input_tensor_);
|
||||
}
|
||||
Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels);
|
||||
std::shared_ptr<CutMixBatchOp> op = std::make_shared<CutMixBatchOp>(ImageBatchFormat::kNCHW, 1.0, 1.0);
|
||||
TensorRow in;
|
||||
in.push_back(batched_tensor);
|
||||
in.push_back(batched_labels);
|
||||
TensorRow out;
|
||||
ASSERT_FALSE(op->Compute(in, &out).IsOk());
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,336 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing the CutMixBatch op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
import mindspore.dataset.transforms.vision.utils as mode
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
||||
config_get_set_num_parallel_workers
|
||||
|
||||
DATA_DIR = "../data/dataset/testCifar10Data"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
||||
def test_cutmix_batch_success1(plot=False):
|
||||
"""
|
||||
Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
|
||||
"""
|
||||
logger.info("test_cutmix_batch_success1")
|
||||
|
||||
# Original Images
|
||||
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
ds_original = ds_original.batch(5, drop_remainder=True)
|
||||
|
||||
images_original = None
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image
|
||||
else:
|
||||
images_original = np.append(images_original, image, axis=0)
|
||||
|
||||
# CutMix Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
hwc2chw_op = vision.HWC2CHW()
|
||||
data1 = data1.map(input_columns=["image"], operations=hwc2chw_op)
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
images_cutmix = None
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image.transpose(0, 2, 3, 1)
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image.transpose(0, 2, 3, 1), axis=0)
|
||||
if plot:
|
||||
visualize_list(images_original, images_cutmix)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_cutmix[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_cutmix_batch_success2(plot=False):
|
||||
"""
|
||||
Test CutMixBatch op with default values for alpha and prob on a batch of HWC images
|
||||
"""
|
||||
logger.info("test_cutmix_batch_success2")
|
||||
|
||||
# Original Images
|
||||
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
ds_original = ds_original.batch(5, drop_remainder=True)
|
||||
|
||||
images_original = None
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image
|
||||
else:
|
||||
images_original = np.append(images_original, image, axis=0)
|
||||
|
||||
# CutMix Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
images_cutmix = None
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
if plot:
|
||||
visualize_list(images_original, images_cutmix)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_cutmix[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_cutmix_batch_nhwc_md5():
|
||||
"""
|
||||
Test CutMixBatch on a batch of HWC images with MD5:
|
||||
"""
|
||||
logger.info("test_cutmix_batch_nhwc_md5")
|
||||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# CutMixBatch Images
|
||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data = data.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data = data.batch(5, drop_remainder=True)
|
||||
data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
filename = "cutmix_batch_c_nhwc_result.npz"
|
||||
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
# Restore config setting
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_cutmix_batch_nchw_md5():
|
||||
"""
|
||||
Test CutMixBatch on a batch of CHW images with MD5:
|
||||
"""
|
||||
logger.info("test_cutmix_batch_nchw_md5")
|
||||
original_seed = config_get_set_seed(0)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# CutMixBatch Images
|
||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
hwc2chw_op = vision.HWC2CHW()
|
||||
data = data.map(input_columns=["image"], operations=hwc2chw_op)
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data = data.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
||||
data = data.batch(5, drop_remainder=True)
|
||||
data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
filename = "cutmix_batch_c_nchw_result.npz"
|
||||
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
# Restore config setting
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail1():
|
||||
"""
|
||||
Test CutMixBatch Fail 1
|
||||
We expect this to fail because the images and labels are not batched
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail1")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
with pytest.raises(RuntimeError) as error:
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
error_message = "You must batch before calling CutMixBatch"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail2():
|
||||
"""
|
||||
Test CutMixBatch Fail 2
|
||||
We expect this to fail because alpha is negative
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail2")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
|
||||
error_message = "Input is not within the required interval"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail3():
|
||||
"""
|
||||
Test CutMixBatch Fail 2
|
||||
We expect this to fail because prob is larger than 1
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail3")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
|
||||
error_message = "Input is not within the required interval"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail4():
|
||||
"""
|
||||
Test CutMixBatch Fail 2
|
||||
We expect this to fail because prob is negative
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail4")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
|
||||
error_message = "Input is not within the required interval"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail5():
|
||||
"""
|
||||
Test CutMixBatch op
|
||||
We expect this to fail because label column is not passed to cutmix_batch
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail5")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image"], operations=cutmix_batch_op)
|
||||
|
||||
with pytest.raises(RuntimeError) as error:
|
||||
images_cutmix = np.array([])
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
error_message = "Both images and labels columns are required"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail6():
|
||||
"""
|
||||
Test CutMixBatch op
|
||||
We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail6")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
with pytest.raises(RuntimeError) as error:
|
||||
images_cutmix = np.array([])
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
error_message = "CutMixBatch: Image doesn't match the given image format."
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail7():
|
||||
"""
|
||||
Test CutMixBatch op
|
||||
We expect this to fail because labels are not in one-hot format
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail7")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
with pytest.raises(RuntimeError) as error:
|
||||
images_cutmix = np.array([])
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
error_message = "CutMixBatch: Label's must be in one-hot format and in a batch"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cutmix_batch_success1(plot=True)
|
||||
test_cutmix_batch_success2(plot=True)
|
||||
test_cutmix_batch_nchw_md5()
|
||||
test_cutmix_batch_nhwc_md5()
|
||||
test_cutmix_batch_fail1()
|
||||
test_cutmix_batch_fail2()
|
||||
test_cutmix_batch_fail3()
|
||||
test_cutmix_batch_fail4()
|
||||
test_cutmix_batch_fail5()
|
||||
test_cutmix_batch_fail6()
|
||||
test_cutmix_batch_fail7()
|
Loading…
Reference in New Issue