!45156 [MD] Support float label in MixUpBatch and CutMixBatch
Merge pull request !45156 from xiaotianci/mix_float_label
This commit is contained in:
commit
72389ba6d3
|
@ -75,11 +75,10 @@ Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"CutMixBatch: please make sure images are <H,W,C> or <C,H,W> format, and batched before calling CutMixBatch.");
|
"CutMixBatch: please make sure images are <H,W,C> or <C,H,W> format, and batched before calling CutMixBatch.");
|
||||||
}
|
}
|
||||||
if (!input.at(1)->type().IsInt()) {
|
|
||||||
RETURN_STATUS_UNEXPECTED(
|
CHECK_FAIL_RETURN_UNEXPECTED(input.at(1)->type().IsNumeric(),
|
||||||
"CutMixBatch: Wrong labels type. The second column (labels) must only include int types, but got:" +
|
"CutMixBatch: invalid label type, label must be in a numeric type, but got: " +
|
||||||
input.at(1)->type().ToString());
|
input.at(1)->type().ToString() + ". You may need to perform OneHot first.");
|
||||||
}
|
|
||||||
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
|
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"CutMixBatch: wrong labels shape. "
|
"CutMixBatch: wrong labels shape. "
|
||||||
|
@ -105,19 +104,19 @@ Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
|
Status CutMixBatchOp::ComputeImage(const std::shared_ptr<Tensor> &image, int64_t rand_indx_i, float lam,
|
||||||
std::shared_ptr<Tensor> *image_i) {
|
float *label_lam, std::shared_ptr<Tensor> *image_i) {
|
||||||
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
|
std::vector<int64_t> image_shape = image->shape().AsVector();
|
||||||
int x, y, crop_width, crop_height;
|
int x, y, crop_width, crop_height;
|
||||||
// Get a random image
|
// Get a random image
|
||||||
TensorShape remaining({-1});
|
TensorShape remaining({-1});
|
||||||
uchar *start_addr_of_index = nullptr;
|
uchar *start_addr_of_index = nullptr;
|
||||||
std::shared_ptr<Tensor> rand_image;
|
std::shared_ptr<Tensor> rand_image;
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
|
RETURN_IF_NOT_OK(image->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(
|
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(
|
||||||
TensorShape({image_shape[kDimensionOne], image_shape[kDimensionTwo], image_shape[kDimensionThree]}),
|
TensorShape({image_shape[kDimensionOne], image_shape[kDimensionTwo], image_shape[kDimensionThree]}), image->type(),
|
||||||
input.at(0)->type(), start_addr_of_index, &rand_image));
|
start_addr_of_index, &rand_image));
|
||||||
|
|
||||||
// Compute image
|
// Compute image
|
||||||
if (image_batch_format_ == ImageBatchFormat::kNHWC) {
|
if (image_batch_format_ == ImageBatchFormat::kNHWC) {
|
||||||
|
@ -154,30 +153,22 @@ Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_in
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CutMixBatchOp::ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
|
Status CutMixBatchOp::ComputeLabel(const std::shared_ptr<Tensor> &label, int64_t rand_indx_i, int64_t index_i,
|
||||||
const int64_t row_labels, const int64_t num_classes,
|
int64_t row_labels, int64_t num_classes, std::size_t label_shape_size,
|
||||||
const std::size_t label_shape_size, const float label_lam,
|
float label_lam, std::shared_ptr<Tensor> *out_labels) {
|
||||||
std::shared_ptr<Tensor> *out_labels) {
|
|
||||||
// Compute labels
|
// Compute labels
|
||||||
|
std::shared_ptr<Tensor> float_label;
|
||||||
|
RETURN_IF_NOT_OK(TypeCast(label, &float_label, DataType(DataType::DE_FLOAT32)));
|
||||||
for (int64_t j = 0; j < row_labels; j++) {
|
for (int64_t j = 0; j < row_labels; j++) {
|
||||||
for (int64_t k = 0; k < num_classes; k++) {
|
for (int64_t k = 0; k < num_classes; k++) {
|
||||||
std::vector<int64_t> first_index =
|
std::vector<int64_t> first_index =
|
||||||
label_shape_size == kMaxLabelShapeSize ? std::vector{index_i, j, k} : std::vector{index_i, k};
|
label_shape_size == kMaxLabelShapeSize ? std::vector{index_i, j, k} : std::vector{index_i, k};
|
||||||
std::vector<int64_t> second_index =
|
std::vector<int64_t> second_index =
|
||||||
label_shape_size == kMaxLabelShapeSize ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k};
|
label_shape_size == kMaxLabelShapeSize ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k};
|
||||||
if (input.at(1)->type().IsSignedInt()) {
|
float first_value, second_value;
|
||||||
int64_t first_value, second_value;
|
RETURN_IF_NOT_OK(float_label->GetItemAt(&first_value, first_index));
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
RETURN_IF_NOT_OK(float_label->GetItemAt(&second_value, second_index));
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||||
RETURN_IF_NOT_OK(
|
|
||||||
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
|
||||||
} else {
|
|
||||||
uint64_t first_value, second_value;
|
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
|
||||||
RETURN_IF_NOT_OK(
|
|
||||||
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,7 +199,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
|
|
||||||
// Tensor holding the output labels
|
// Tensor holding the output labels
|
||||||
std::shared_ptr<Tensor> out_labels;
|
std::shared_ptr<Tensor> out_labels;
|
||||||
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType(DataType::DE_FLOAT32)));
|
RETURN_IF_NOT_OK(TypeCast(input.at(1), &out_labels, DataType(DataType::DE_FLOAT32)));
|
||||||
int64_t row_labels = label_shape.size() == kValueThree ? label_shape[kDimensionOne] : kValueOne;
|
int64_t row_labels = label_shape.size() == kValueThree ? label_shape[kDimensionOne] : kValueOne;
|
||||||
int64_t num_classes = label_shape.size() == kValueThree ? label_shape[kDimensionTwo] : label_shape[kDimensionOne];
|
int64_t num_classes = label_shape.size() == kValueThree ? label_shape[kDimensionTwo] : label_shape[kDimensionOne];
|
||||||
|
|
||||||
|
@ -227,9 +218,9 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
if (random_number < prob_) {
|
if (random_number < prob_) {
|
||||||
float label_lam; // lambda used for labels
|
float label_lam; // lambda used for labels
|
||||||
// Compute image
|
// Compute image
|
||||||
RETURN_IF_NOT_OK(ComputeImage(input, rand_indx[i], lam, &label_lam, &images[i]));
|
RETURN_IF_NOT_OK(ComputeImage(input.at(0), rand_indx[i], lam, &label_lam, &images[i]));
|
||||||
// Compute labels
|
// Compute labels
|
||||||
RETURN_IF_NOT_OK(ComputeLabel(input, rand_indx[i], static_cast<int64_t>(i), row_labels, num_classes,
|
RETURN_IF_NOT_OK(ComputeLabel(input.at(1), rand_indx[i], static_cast<int64_t>(i), row_labels, num_classes,
|
||||||
label_shape.size(), label_lam, &out_labels));
|
label_shape.size(), label_lam, &out_labels));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ class CutMixBatchOp : public TensorOp {
|
||||||
/// \param[in] label_lam Lambda used for labels, will be updated after computing each image.
|
/// \param[in] label_lam Lambda used for labels, will be updated after computing each image.
|
||||||
/// \param[in] image_i The result of the i-th computed image.
|
/// \param[in] image_i The result of the i-th computed image.
|
||||||
/// \returns Status
|
/// \returns Status
|
||||||
Status ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
|
Status ComputeImage(const std::shared_ptr<Tensor> &image, int64_t rand_indx_i, float lam, float *label_lam,
|
||||||
std::shared_ptr<Tensor> *image_i);
|
std::shared_ptr<Tensor> *image_i);
|
||||||
|
|
||||||
/// \brief Helper function used in Compute to compute each label corresponding to each image.
|
/// \brief Helper function used in Compute to compute each label corresponding to each image.
|
||||||
|
@ -67,9 +67,10 @@ class CutMixBatchOp : public TensorOp {
|
||||||
/// \param[in] label_lam Lambda used for setting the location.
|
/// \param[in] label_lam Lambda used for setting the location.
|
||||||
/// \param[in] out_labels The output of the i-th label, corresponding to the i-th computed image.
|
/// \param[in] out_labels The output of the i-th label, corresponding to the i-th computed image.
|
||||||
/// \returns Status
|
/// \returns Status
|
||||||
Status ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
|
Status ComputeLabel(const std::shared_ptr<Tensor> &label, int64_t rand_indx_i, int64_t index_i, int64_t row_labels,
|
||||||
const int64_t row_labels, const int64_t num_classes, const std::size_t label_shape_size,
|
int64_t num_classes, std::size_t label_shape_size, float label_lam,
|
||||||
const float label_lam, std::shared_ptr<Tensor> *out_labels);
|
std::shared_ptr<Tensor> *out_labels);
|
||||||
|
|
||||||
float alpha_;
|
float alpha_;
|
||||||
float prob_;
|
float prob_;
|
||||||
ImageBatchFormat image_batch_format_;
|
ImageBatchFormat image_batch_format_;
|
||||||
|
|
|
@ -39,9 +39,9 @@ constexpr int64_t value_three = 3;
|
||||||
|
|
||||||
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); }
|
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); }
|
||||||
|
|
||||||
Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tensor> *out_labels,
|
Status MixUpBatchOp::ComputeLabels(const std::shared_ptr<Tensor> &label, std::shared_ptr<Tensor> *out_labels,
|
||||||
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape,
|
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, float lam,
|
||||||
const float lam, const size_t images_size) {
|
size_t images_size) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||||
images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
|
images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
|
||||||
"The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size));
|
"The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size));
|
||||||
|
@ -50,7 +50,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
|
||||||
}
|
}
|
||||||
std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
|
std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
|
||||||
|
|
||||||
RETURN_IF_NOT_OK(TypeCast(input.at(1), out_labels, DataType(DataType::DE_FLOAT32)));
|
std::shared_ptr<Tensor> float_label;
|
||||||
|
RETURN_IF_NOT_OK(TypeCast(label, &float_label, DataType(DataType::DE_FLOAT32)));
|
||||||
|
RETURN_IF_NOT_OK(TypeCast(label, out_labels, DataType(DataType::DE_FLOAT32)));
|
||||||
|
|
||||||
int64_t row_labels = label_shape.size() == kMaxLabelShapeSize ? label_shape[1] : 1;
|
int64_t row_labels = label_shape.size() == kMaxLabelShapeSize ? label_shape[1] : 1;
|
||||||
int64_t num_classes = label_shape.size() == kMaxLabelShapeSize ? label_shape[dimension_two] : label_shape[1];
|
int64_t num_classes = label_shape.size() == kMaxLabelShapeSize ? label_shape[dimension_two] : label_shape[1];
|
||||||
|
@ -63,17 +65,10 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
|
||||||
std::vector<int64_t> second_index = label_shape.size() == kMaxLabelShapeSize
|
std::vector<int64_t> second_index = label_shape.size() == kMaxLabelShapeSize
|
||||||
? std::vector{(*rand_indx)[static_cast<size_t>(i)], j, k}
|
? std::vector{(*rand_indx)[static_cast<size_t>(i)], j, k}
|
||||||
: std::vector{(*rand_indx)[static_cast<size_t>(i)], k};
|
: std::vector{(*rand_indx)[static_cast<size_t>(i)], k};
|
||||||
if (input.at(1)->type().IsSignedInt()) {
|
float first_value, second_value;
|
||||||
int64_t first_value, second_value;
|
RETURN_IF_NOT_OK(float_label->GetItemAt(&first_value, first_index));
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
RETURN_IF_NOT_OK(float_label->GetItemAt(&second_value, second_index));
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
|
||||||
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
|
|
||||||
} else {
|
|
||||||
uint64_t first_value, second_value;
|
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
|
||||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
|
||||||
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -96,12 +91,10 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
", but got: " + std::to_string(image_shape.size()) +
|
", but got: " + std::to_string(image_shape.size()) +
|
||||||
", make sure image shape are <H,W,C> or <C,H,W> and batched before calling MixUpBatch.");
|
", make sure image shape are <H,W,C> or <C,H,W> and batched before calling MixUpBatch.");
|
||||||
}
|
}
|
||||||
if (!input.at(1)->type().IsInt()) {
|
|
||||||
RETURN_STATUS_UNEXPECTED(
|
|
||||||
"MixUpBatch: wrong labels type. The second column (labels) must only include int types, but got: " +
|
|
||||||
input.at(1)->type().ToString());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input.at(1)->type().IsNumeric(),
|
||||||
|
"MixUpBatch: invalid label type, label must be in a numeric type, but got: " +
|
||||||
|
input.at(1)->type().ToString() + ". You may need to perform OneHot first.");
|
||||||
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
|
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
|
||||||
RETURN_STATUS_UNEXPECTED(
|
RETURN_STATUS_UNEXPECTED(
|
||||||
"MixUpBatch: wrong labels shape. "
|
"MixUpBatch: wrong labels shape. "
|
||||||
|
@ -137,7 +130,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
std::shared_ptr<Tensor> out_labels;
|
std::shared_ptr<Tensor> out_labels;
|
||||||
|
|
||||||
// Compute labels
|
// Compute labels
|
||||||
RETURN_IF_NOT_OK(ComputeLabels(input, &out_labels, &rand_indx, label_shape, lam, images.size()));
|
RETURN_IF_NOT_OK(ComputeLabels(input.at(1), &out_labels, &rand_indx, label_shape, lam, images.size()));
|
||||||
|
|
||||||
// Compute images
|
// Compute images
|
||||||
for (int64_t i = 0; i < images.size(); i++) {
|
for (int64_t i = 0; i < images.size(); i++) {
|
||||||
|
|
|
@ -43,8 +43,10 @@ class MixUpBatchOp : public TensorOp {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// a helper function to shorten the main Compute function
|
// a helper function to shorten the main Compute function
|
||||||
Status ComputeLabels(const TensorRow &input, std::shared_ptr<Tensor> *out_labels, std::vector<int64_t> *rand_indx,
|
Status ComputeLabels(const std::shared_ptr<Tensor> &label, std::shared_ptr<Tensor> *out_labels,
|
||||||
const std::vector<int64_t> &label_shape, const float lam, const size_t images_size);
|
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, float lam,
|
||||||
|
size_t images_size);
|
||||||
|
|
||||||
float alpha_;
|
float alpha_;
|
||||||
std::mt19937 rnd_;
|
std::mt19937 rnd_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.vision as vision
|
import mindspore.dataset.vision as vision
|
||||||
import mindspore.dataset.transforms as data_trans
|
import mindspore.dataset.transforms as transforms
|
||||||
import mindspore.dataset.vision.utils as mode
|
import mindspore.dataset.vision.utils as mode
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
||||||
|
@ -54,7 +54,7 @@ def test_cutmix_batch_success1(plot=False):
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
hwc2chw_op = vision.HWC2CHW()
|
hwc2chw_op = vision.HWC2CHW()
|
||||||
data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
|
data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
|
||||||
data1 = data1.batch(5, drop_remainder=True)
|
data1 = data1.batch(5, drop_remainder=True)
|
||||||
|
@ -97,7 +97,7 @@ def test_cutmix_batch_success2(plot=False):
|
||||||
|
|
||||||
# CutMix Images
|
# CutMix Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
|
rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
|
||||||
data1 = data1.map(operations=rescale_op, input_columns=["image"])
|
data1 = data1.map(operations=rescale_op, input_columns=["image"])
|
||||||
|
@ -152,7 +152,7 @@ def test_cutmix_batch_success3(plot=False):
|
||||||
resize_op = vision.Resize([224, 224])
|
resize_op = vision.Resize([224, 224])
|
||||||
data1 = data1.map(operations=[resize_op], input_columns=["image"])
|
data1 = data1.map(operations=[resize_op], input_columns=["image"])
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
|
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||||
|
@ -206,7 +206,7 @@ def test_cutmix_batch_success4(plot=False):
|
||||||
resize_op = vision.Resize([224, 224])
|
resize_op = vision.Resize([224, 224])
|
||||||
data1 = data1.map(operations=[resize_op], input_columns=["image"])
|
data1 = data1.map(operations=[resize_op], input_columns=["image"])
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=100)
|
one_hot_op = transforms.OneHot(num_classes=100)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
|
||||||
|
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
|
||||||
|
@ -242,7 +242,7 @@ def test_cutmix_batch_nhwc_md5():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data = data.map(operations=one_hot_op, input_columns=["label"])
|
data = data.map(operations=one_hot_op, input_columns=["label"])
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||||
data = data.batch(5, drop_remainder=True)
|
data = data.batch(5, drop_remainder=True)
|
||||||
|
@ -270,7 +270,7 @@ def test_cutmix_batch_nchw_md5():
|
||||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
hwc2chw_op = vision.HWC2CHW()
|
hwc2chw_op = vision.HWC2CHW()
|
||||||
data = data.map(operations=hwc2chw_op, input_columns=["image"])
|
data = data.map(operations=hwc2chw_op, input_columns=["image"])
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data = data.map(operations=one_hot_op, input_columns=["label"])
|
data = data.map(operations=one_hot_op, input_columns=["label"])
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
||||||
data = data.batch(5, drop_remainder=True)
|
data = data.batch(5, drop_remainder=True)
|
||||||
|
@ -284,6 +284,52 @@ def test_cutmix_batch_nchw_md5():
|
||||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cutmix_batch_float_label():
|
||||||
|
"""
|
||||||
|
Feature: CutMixBatch
|
||||||
|
Description: Test CutMixBatch with label in type of float
|
||||||
|
Expectation: Output is as expected
|
||||||
|
"""
|
||||||
|
original_seed = config_get_set_seed(0)
|
||||||
|
|
||||||
|
image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
|
||||||
|
label = np.random.randint(0, 5, (3, 1))
|
||||||
|
decode_label = transforms.OneHot(5)(label)
|
||||||
|
float_label = transforms.TypeCast(float)(decode_label)
|
||||||
|
_, mix_label = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)(image, float_label)
|
||||||
|
expected_label = np.array([[0., 0.9285714, 0., 0., 0.0714286],
|
||||||
|
[0., 0., 0., 0., 1.],
|
||||||
|
[0., 0.02040815, 0., 0., 0.97959185]])
|
||||||
|
np.testing.assert_almost_equal(mix_label, expected_label)
|
||||||
|
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cutmix_batch_twice():
|
||||||
|
"""
|
||||||
|
Feature: CutMixBatch
|
||||||
|
Description: Test CutMixBatch called twice
|
||||||
|
Expectation: Output is as expected
|
||||||
|
"""
|
||||||
|
original_seed = config_get_set_seed(5)
|
||||||
|
|
||||||
|
dataset = ds.Cifar10Dataset(DATA_DIR, num_samples=3, shuffle=False)
|
||||||
|
one_hot = transforms.OneHot(num_classes=10)
|
||||||
|
dataset = dataset.map(operations=one_hot, input_columns=["label"])
|
||||||
|
cut_mix_batch = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 2.0, 0.5)
|
||||||
|
dataset = dataset.batch(3, drop_remainder=False)
|
||||||
|
dataset = dataset.map(operations=cut_mix_batch, input_columns=["image", "label"])
|
||||||
|
dataset = dataset.map(operations=cut_mix_batch, input_columns=["image", "label"])
|
||||||
|
|
||||||
|
expected_label = np.array([[0.58618164, 0.41107178, 0.00274658, 0., 0., 0., 0., 0., 0., 0.],
|
||||||
|
[0.00109863, 0.9766998, 0.02220154, 0., 0., 0., 0., 0., 0., 0.],
|
||||||
|
[0.15673828, 0.02197266, 0.82128906, 0., 0., 0., 0., 0., 0., 0.]])
|
||||||
|
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
np.testing.assert_almost_equal(item["label"], expected_label)
|
||||||
|
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
def test_cutmix_batch_fail1():
|
def test_cutmix_batch_fail1():
|
||||||
"""
|
"""
|
||||||
Feature: CutMixBatch op
|
Feature: CutMixBatch op
|
||||||
|
@ -295,7 +341,7 @@ def test_cutmix_batch_fail1():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||||
with pytest.raises(RuntimeError) as error:
|
with pytest.raises(RuntimeError) as error:
|
||||||
|
@ -320,7 +366,7 @@ def test_cutmix_batch_fail2():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
with pytest.raises(ValueError) as error:
|
with pytest.raises(ValueError) as error:
|
||||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
|
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
|
||||||
|
@ -339,7 +385,7 @@ def test_cutmix_batch_fail3():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
with pytest.raises(ValueError) as error:
|
with pytest.raises(ValueError) as error:
|
||||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
|
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
|
||||||
|
@ -358,7 +404,7 @@ def test_cutmix_batch_fail4():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
with pytest.raises(ValueError) as error:
|
with pytest.raises(ValueError) as error:
|
||||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
|
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
|
||||||
|
@ -377,7 +423,7 @@ def test_cutmix_batch_fail5():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||||
data1 = data1.batch(5, drop_remainder=True)
|
data1 = data1.batch(5, drop_remainder=True)
|
||||||
|
@ -405,7 +451,7 @@ def test_cutmix_batch_fail6():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
||||||
data1 = data1.batch(5, drop_remainder=True)
|
data1 = data1.batch(5, drop_remainder=True)
|
||||||
|
@ -459,7 +505,7 @@ def test_cutmix_batch_fail8():
|
||||||
# CutMixBatch Images
|
# CutMixBatch Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
with pytest.raises(ValueError) as error:
|
with pytest.raises(ValueError) as error:
|
||||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
|
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
|
||||||
|
@ -467,6 +513,20 @@ def test_cutmix_batch_fail8():
|
||||||
assert error_message in str(error.value)
|
assert error_message in str(error.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cut_mix_batch_invalid_label_type():
|
||||||
|
"""
|
||||||
|
Feature: CutMixBatch
|
||||||
|
Description: Test CutMixBatch with label in str type
|
||||||
|
Expectation: Error is raised as expected
|
||||||
|
"""
|
||||||
|
image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
|
||||||
|
label = np.array([["one"], ["two"], ["three"]])
|
||||||
|
with pytest.raises(RuntimeError) as error:
|
||||||
|
_ = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)(image, label)
|
||||||
|
error_message = "CutMixBatch: invalid label type, label must be in a numeric type"
|
||||||
|
assert error_message in str(error.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_cutmix_batch_success1(plot=True)
|
test_cutmix_batch_success1(plot=True)
|
||||||
test_cutmix_batch_success2(plot=True)
|
test_cutmix_batch_success2(plot=True)
|
||||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.vision as vision
|
import mindspore.dataset.vision as vision
|
||||||
import mindspore.dataset.transforms as data_trans
|
import mindspore.dataset.transforms as transforms
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
||||||
config_get_set_num_parallel_workers
|
config_get_set_num_parallel_workers
|
||||||
|
@ -53,7 +53,7 @@ def test_mixup_batch_success1(plot=False):
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
mixup_batch_op = vision.MixUpBatch(2)
|
mixup_batch_op = vision.MixUpBatch(2)
|
||||||
data1 = data1.batch(5, drop_remainder=True)
|
data1 = data1.batch(5, drop_remainder=True)
|
||||||
|
@ -102,7 +102,7 @@ def test_mixup_batch_success2(plot=False):
|
||||||
decode_op = vision.Decode()
|
decode_op = vision.Decode()
|
||||||
data1 = data1.map(operations=[decode_op], input_columns=["image"])
|
data1 = data1.map(operations=[decode_op], input_columns=["image"])
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
|
|
||||||
mixup_batch_op = vision.MixUpBatch(2.0)
|
mixup_batch_op = vision.MixUpBatch(2.0)
|
||||||
|
@ -147,7 +147,7 @@ def test_mixup_batch_success3(plot=False):
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
mixup_batch_op = vision.MixUpBatch()
|
mixup_batch_op = vision.MixUpBatch()
|
||||||
data1 = data1.batch(5, drop_remainder=True)
|
data1 = data1.batch(5, drop_remainder=True)
|
||||||
|
@ -196,7 +196,7 @@ def test_mixup_batch_success4(plot=False):
|
||||||
decode_op = vision.Decode()
|
decode_op = vision.Decode()
|
||||||
data1 = data1.map(operations=[decode_op], input_columns=["image"])
|
data1 = data1.map(operations=[decode_op], input_columns=["image"])
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=100)
|
one_hot_op = transforms.OneHot(num_classes=100)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
|
||||||
|
|
||||||
mixup_batch_op = vision.MixUpBatch()
|
mixup_batch_op = vision.MixUpBatch()
|
||||||
|
@ -232,7 +232,7 @@ def test_mixup_batch_md5():
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data = data.map(operations=one_hot_op, input_columns=["label"])
|
data = data.map(operations=one_hot_op, input_columns=["label"])
|
||||||
mixup_batch_op = vision.MixUpBatch()
|
mixup_batch_op = vision.MixUpBatch()
|
||||||
data = data.batch(5, drop_remainder=True)
|
data = data.batch(5, drop_remainder=True)
|
||||||
|
@ -246,6 +246,52 @@ def test_mixup_batch_md5():
|
||||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mixup_batch_float_label():
|
||||||
|
"""
|
||||||
|
Feature: MixUpBatch
|
||||||
|
Description: Test MixUpBatch with label in type of float
|
||||||
|
Expectation: Output is as expected
|
||||||
|
"""
|
||||||
|
original_seed = config_get_set_seed(0)
|
||||||
|
|
||||||
|
image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
|
||||||
|
label = np.random.randint(0, 5, (3, 1))
|
||||||
|
decode_label = transforms.OneHot(5)(label)
|
||||||
|
float_label = transforms.TypeCast(float)(decode_label)
|
||||||
|
_, mix_label = vision.MixUpBatch()(image, float_label)
|
||||||
|
expected_label = np.array([[0., 0.6824126, 0., 0., 0.3175874],
|
||||||
|
[0., 0., 0., 0., 1.],
|
||||||
|
[0., 0.3175874, 0., 0., 0.6824126]])
|
||||||
|
np.testing.assert_almost_equal(mix_label, expected_label)
|
||||||
|
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mixup_batch_twice():
|
||||||
|
"""
|
||||||
|
Feature: MixUpBatch
|
||||||
|
Description: Test MixUpBatch called twice
|
||||||
|
Expectation: Output is as expected
|
||||||
|
"""
|
||||||
|
original_seed = config_get_set_seed(1)
|
||||||
|
|
||||||
|
dataset = ds.Cifar10Dataset(DATA_DIR, num_samples=3, shuffle=False)
|
||||||
|
one_hot = transforms.OneHot(num_classes=10)
|
||||||
|
dataset = dataset.map(operations=one_hot, input_columns=["label"])
|
||||||
|
mix_up_batch = vision.MixUpBatch()
|
||||||
|
dataset = dataset.batch(3, drop_remainder=False)
|
||||||
|
dataset = dataset.map(operations=mix_up_batch, input_columns=["image", "label"])
|
||||||
|
dataset = dataset.map(operations=mix_up_batch, input_columns=["image", "label"])
|
||||||
|
|
||||||
|
expected_label = np.array([[0.29373336, 0.49647674, 0.20978989, 0., 0., 0., 0., 0., 0., 0.],
|
||||||
|
[0.20978989, 0.29373336, 0.49647674, 0., 0., 0., 0., 0., 0., 0.],
|
||||||
|
[0.49647674, 0.20978989, 0.29373336, 0., 0., 0., 0., 0., 0., 0.]])
|
||||||
|
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
np.testing.assert_almost_equal(item["label"], expected_label)
|
||||||
|
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
def test_mixup_batch_fail1():
|
def test_mixup_batch_fail1():
|
||||||
"""
|
"""
|
||||||
Feature: MixUpBatch op
|
Feature: MixUpBatch op
|
||||||
|
@ -268,7 +314,7 @@ def test_mixup_batch_fail1():
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
mixup_batch_op = vision.MixUpBatch(0.1)
|
mixup_batch_op = vision.MixUpBatch(0.1)
|
||||||
with pytest.raises(RuntimeError) as error:
|
with pytest.raises(RuntimeError) as error:
|
||||||
|
@ -304,7 +350,7 @@ def test_mixup_batch_fail2():
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
with pytest.raises(ValueError) as error:
|
with pytest.raises(ValueError) as error:
|
||||||
vision.MixUpBatch(-1)
|
vision.MixUpBatch(-1)
|
||||||
|
@ -333,7 +379,7 @@ def test_mixup_batch_fail3():
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
mixup_batch_op = vision.MixUpBatch()
|
mixup_batch_op = vision.MixUpBatch()
|
||||||
data1 = data1.batch(5, drop_remainder=True)
|
data1 = data1.batch(5, drop_remainder=True)
|
||||||
|
@ -372,7 +418,7 @@ def test_mixup_batch_fail4():
|
||||||
# MixUp Images
|
# MixUp Images
|
||||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||||
|
|
||||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
one_hot_op = transforms.OneHot(num_classes=10)
|
||||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||||
with pytest.raises(ValueError) as error:
|
with pytest.raises(ValueError) as error:
|
||||||
vision.MixUpBatch(0.0)
|
vision.MixUpBatch(0.0)
|
||||||
|
@ -417,6 +463,20 @@ def test_mixup_batch_fail5():
|
||||||
assert error_message in str(error.value)
|
assert error_message in str(error.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mix_up_batch_invalid_label_type():
|
||||||
|
"""
|
||||||
|
Feature: MixUpBatch
|
||||||
|
Description: Test MixUpBatch with label in str type
|
||||||
|
Expectation: Error is raised as expected
|
||||||
|
"""
|
||||||
|
image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
|
||||||
|
label = np.array([["one"], ["two"], ["three"]])
|
||||||
|
with pytest.raises(RuntimeError) as error:
|
||||||
|
_ = vision.MixUpBatch()(image, label)
|
||||||
|
error_message = "MixUpBatch: invalid label type, label must be in a numeric type"
|
||||||
|
assert error_message in str(error.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_mixup_batch_success1(plot=True)
|
test_mixup_batch_success1(plot=True)
|
||||||
test_mixup_batch_success2(plot=True)
|
test_mixup_batch_success2(plot=True)
|
||||||
|
|
Loading…
Reference in New Issue