!45156 [MD] Support float label in MixUpBatch and CutMixBatch
Merge pull request !45156 from xiaotianci/mix_float_label
This commit is contained in:
commit
72389ba6d3
|
@ -75,11 +75,10 @@ Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
|
|||
RETURN_STATUS_UNEXPECTED(
|
||||
"CutMixBatch: please make sure images are <H,W,C> or <C,H,W> format, and batched before calling CutMixBatch.");
|
||||
}
|
||||
if (!input.at(1)->type().IsInt()) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"CutMixBatch: Wrong labels type. The second column (labels) must only include int types, but got:" +
|
||||
input.at(1)->type().ToString());
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input.at(1)->type().IsNumeric(),
|
||||
"CutMixBatch: invalid label type, label must be in a numeric type, but got: " +
|
||||
input.at(1)->type().ToString() + ". You may need to perform OneHot first.");
|
||||
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"CutMixBatch: wrong labels shape. "
|
||||
|
@ -105,19 +104,19 @@ Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
|
||||
std::shared_ptr<Tensor> *image_i) {
|
||||
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
|
||||
Status CutMixBatchOp::ComputeImage(const std::shared_ptr<Tensor> &image, int64_t rand_indx_i, float lam,
|
||||
float *label_lam, std::shared_ptr<Tensor> *image_i) {
|
||||
std::vector<int64_t> image_shape = image->shape().AsVector();
|
||||
int x, y, crop_width, crop_height;
|
||||
// Get a random image
|
||||
TensorShape remaining({-1});
|
||||
uchar *start_addr_of_index = nullptr;
|
||||
std::shared_ptr<Tensor> rand_image;
|
||||
|
||||
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(image->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(
|
||||
TensorShape({image_shape[kDimensionOne], image_shape[kDimensionTwo], image_shape[kDimensionThree]}),
|
||||
input.at(0)->type(), start_addr_of_index, &rand_image));
|
||||
TensorShape({image_shape[kDimensionOne], image_shape[kDimensionTwo], image_shape[kDimensionThree]}), image->type(),
|
||||
start_addr_of_index, &rand_image));
|
||||
|
||||
// Compute image
|
||||
if (image_batch_format_ == ImageBatchFormat::kNHWC) {
|
||||
|
@ -154,30 +153,22 @@ Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_in
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CutMixBatchOp::ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
|
||||
const int64_t row_labels, const int64_t num_classes,
|
||||
const std::size_t label_shape_size, const float label_lam,
|
||||
std::shared_ptr<Tensor> *out_labels) {
|
||||
Status CutMixBatchOp::ComputeLabel(const std::shared_ptr<Tensor> &label, int64_t rand_indx_i, int64_t index_i,
|
||||
int64_t row_labels, int64_t num_classes, std::size_t label_shape_size,
|
||||
float label_lam, std::shared_ptr<Tensor> *out_labels) {
|
||||
// Compute labels
|
||||
std::shared_ptr<Tensor> float_label;
|
||||
RETURN_IF_NOT_OK(TypeCast(label, &float_label, DataType(DataType::DE_FLOAT32)));
|
||||
for (int64_t j = 0; j < row_labels; j++) {
|
||||
for (int64_t k = 0; k < num_classes; k++) {
|
||||
std::vector<int64_t> first_index =
|
||||
label_shape_size == kMaxLabelShapeSize ? std::vector{index_i, j, k} : std::vector{index_i, k};
|
||||
std::vector<int64_t> second_index =
|
||||
label_shape_size == kMaxLabelShapeSize ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k};
|
||||
if (input.at(1)->type().IsSignedInt()) {
|
||||
int64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK(
|
||||
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
} else {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK(
|
||||
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
}
|
||||
float first_value, second_value;
|
||||
RETURN_IF_NOT_OK(float_label->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(float_label->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,7 +199,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
|
||||
// Tensor holding the output labels
|
||||
std::shared_ptr<Tensor> out_labels;
|
||||
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType(DataType::DE_FLOAT32)));
|
||||
RETURN_IF_NOT_OK(TypeCast(input.at(1), &out_labels, DataType(DataType::DE_FLOAT32)));
|
||||
int64_t row_labels = label_shape.size() == kValueThree ? label_shape[kDimensionOne] : kValueOne;
|
||||
int64_t num_classes = label_shape.size() == kValueThree ? label_shape[kDimensionTwo] : label_shape[kDimensionOne];
|
||||
|
||||
|
@ -227,9 +218,9 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
if (random_number < prob_) {
|
||||
float label_lam; // lambda used for labels
|
||||
// Compute image
|
||||
RETURN_IF_NOT_OK(ComputeImage(input, rand_indx[i], lam, &label_lam, &images[i]));
|
||||
RETURN_IF_NOT_OK(ComputeImage(input.at(0), rand_indx[i], lam, &label_lam, &images[i]));
|
||||
// Compute labels
|
||||
RETURN_IF_NOT_OK(ComputeLabel(input, rand_indx[i], static_cast<int64_t>(i), row_labels, num_classes,
|
||||
RETURN_IF_NOT_OK(ComputeLabel(input.at(1), rand_indx[i], static_cast<int64_t>(i), row_labels, num_classes,
|
||||
label_shape.size(), label_lam, &out_labels));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ class CutMixBatchOp : public TensorOp {
|
|||
/// \param[in] label_lam Lambda used for labels, will be updated after computing each image.
|
||||
/// \param[in] image_i The result of the i-th computed image.
|
||||
/// \returns Status
|
||||
Status ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
|
||||
Status ComputeImage(const std::shared_ptr<Tensor> &image, int64_t rand_indx_i, float lam, float *label_lam,
|
||||
std::shared_ptr<Tensor> *image_i);
|
||||
|
||||
/// \brief Helper function used in Compute to compute each label corresponding to each image.
|
||||
|
@ -67,9 +67,10 @@ class CutMixBatchOp : public TensorOp {
|
|||
/// \param[in] label_lam Lambda used for setting the location.
|
||||
/// \param[in] out_labels The output of the i-th label, corresponding to the i-th computed image.
|
||||
/// \returns Status
|
||||
Status ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
|
||||
const int64_t row_labels, const int64_t num_classes, const std::size_t label_shape_size,
|
||||
const float label_lam, std::shared_ptr<Tensor> *out_labels);
|
||||
Status ComputeLabel(const std::shared_ptr<Tensor> &label, int64_t rand_indx_i, int64_t index_i, int64_t row_labels,
|
||||
int64_t num_classes, std::size_t label_shape_size, float label_lam,
|
||||
std::shared_ptr<Tensor> *out_labels);
|
||||
|
||||
float alpha_;
|
||||
float prob_;
|
||||
ImageBatchFormat image_batch_format_;
|
||||
|
|
|
@ -39,9 +39,9 @@ constexpr int64_t value_three = 3;
|
|||
|
||||
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); }
|
||||
|
||||
Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tensor> *out_labels,
|
||||
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape,
|
||||
const float lam, const size_t images_size) {
|
||||
Status MixUpBatchOp::ComputeLabels(const std::shared_ptr<Tensor> &label, std::shared_ptr<Tensor> *out_labels,
|
||||
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, float lam,
|
||||
size_t images_size) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
|
||||
"The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size));
|
||||
|
@ -50,7 +50,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
|
|||
}
|
||||
std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
|
||||
|
||||
RETURN_IF_NOT_OK(TypeCast(input.at(1), out_labels, DataType(DataType::DE_FLOAT32)));
|
||||
std::shared_ptr<Tensor> float_label;
|
||||
RETURN_IF_NOT_OK(TypeCast(label, &float_label, DataType(DataType::DE_FLOAT32)));
|
||||
RETURN_IF_NOT_OK(TypeCast(label, out_labels, DataType(DataType::DE_FLOAT32)));
|
||||
|
||||
int64_t row_labels = label_shape.size() == kMaxLabelShapeSize ? label_shape[1] : 1;
|
||||
int64_t num_classes = label_shape.size() == kMaxLabelShapeSize ? label_shape[dimension_two] : label_shape[1];
|
||||
|
@ -63,17 +65,10 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
|
|||
std::vector<int64_t> second_index = label_shape.size() == kMaxLabelShapeSize
|
||||
? std::vector{(*rand_indx)[static_cast<size_t>(i)], j, k}
|
||||
: std::vector{(*rand_indx)[static_cast<size_t>(i)], k};
|
||||
if (input.at(1)->type().IsSignedInt()) {
|
||||
int64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
|
||||
} else {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
|
||||
}
|
||||
float first_value, second_value;
|
||||
RETURN_IF_NOT_OK(float_label->GetItemAt(&first_value, first_index));
|
||||
RETURN_IF_NOT_OK(float_label->GetItemAt(&second_value, second_index));
|
||||
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -96,12 +91,10 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
", but got: " + std::to_string(image_shape.size()) +
|
||||
", make sure image shape are <H,W,C> or <C,H,W> and batched before calling MixUpBatch.");
|
||||
}
|
||||
if (!input.at(1)->type().IsInt()) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"MixUpBatch: wrong labels type. The second column (labels) must only include int types, but got: " +
|
||||
input.at(1)->type().ToString());
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input.at(1)->type().IsNumeric(),
|
||||
"MixUpBatch: invalid label type, label must be in a numeric type, but got: " +
|
||||
input.at(1)->type().ToString() + ". You may need to perform OneHot first.");
|
||||
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"MixUpBatch: wrong labels shape. "
|
||||
|
@ -137,7 +130,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
std::shared_ptr<Tensor> out_labels;
|
||||
|
||||
// Compute labels
|
||||
RETURN_IF_NOT_OK(ComputeLabels(input, &out_labels, &rand_indx, label_shape, lam, images.size()));
|
||||
RETURN_IF_NOT_OK(ComputeLabels(input.at(1), &out_labels, &rand_indx, label_shape, lam, images.size()));
|
||||
|
||||
// Compute images
|
||||
for (int64_t i = 0; i < images.size(); i++) {
|
||||
|
|
|
@ -43,8 +43,10 @@ class MixUpBatchOp : public TensorOp {
|
|||
|
||||
private:
|
||||
// a helper function to shorten the main Compute function
|
||||
Status ComputeLabels(const TensorRow &input, std::shared_ptr<Tensor> *out_labels, std::vector<int64_t> *rand_indx,
|
||||
const std::vector<int64_t> &label_shape, const float lam, const size_t images_size);
|
||||
Status ComputeLabels(const std::shared_ptr<Tensor> &label, std::shared_ptr<Tensor> *out_labels,
|
||||
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, float lam,
|
||||
size_t images_size);
|
||||
|
||||
float alpha_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
import mindspore.dataset.transforms as data_trans
|
||||
import mindspore.dataset.transforms as transforms
|
||||
import mindspore.dataset.vision.utils as mode
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
||||
|
@ -54,7 +54,7 @@ def test_cutmix_batch_success1(plot=False):
|
|||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
hwc2chw_op = vision.HWC2CHW()
|
||||
data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
|
@ -97,7 +97,7 @@ def test_cutmix_batch_success2(plot=False):
|
|||
|
||||
# CutMix Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
|
||||
data1 = data1.map(operations=rescale_op, input_columns=["image"])
|
||||
|
@ -152,7 +152,7 @@ def test_cutmix_batch_success3(plot=False):
|
|||
resize_op = vision.Resize([224, 224])
|
||||
data1 = data1.map(operations=[resize_op], input_columns=["image"])
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
|
@ -206,7 +206,7 @@ def test_cutmix_batch_success4(plot=False):
|
|||
resize_op = vision.Resize([224, 224])
|
||||
data1 = data1.map(operations=[resize_op], input_columns=["image"])
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=100)
|
||||
one_hot_op = transforms.OneHot(num_classes=100)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
|
||||
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
|
||||
|
@ -242,7 +242,7 @@ def test_cutmix_batch_nhwc_md5():
|
|||
# CutMixBatch Images
|
||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data = data.map(operations=one_hot_op, input_columns=["label"])
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data = data.batch(5, drop_remainder=True)
|
||||
|
@ -270,7 +270,7 @@ def test_cutmix_batch_nchw_md5():
|
|||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
hwc2chw_op = vision.HWC2CHW()
|
||||
data = data.map(operations=hwc2chw_op, input_columns=["image"])
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data = data.map(operations=one_hot_op, input_columns=["label"])
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
||||
data = data.batch(5, drop_remainder=True)
|
||||
|
@ -284,6 +284,52 @@ def test_cutmix_batch_nchw_md5():
|
|||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_cutmix_batch_float_label():
    """
    Feature: CutMixBatch
    Description: Test CutMixBatch with label in type of float
    Expectation: Output is as expected
    """
    # The expected values below are hard-coded, so the seed is pinned first.
    # NOTE(review): presumably config_get_set_seed also governs the NumPy draws - confirm.
    original_seed = config_get_set_seed(0)
    try:
        image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
        label = np.random.randint(0, 5, (3, 1))

        # One-hot encode the integer labels, then cast to float so that the op
        # receives a float label column.
        decode_label = transforms.OneHot(5)(label)
        float_label = transforms.TypeCast(float)(decode_label)

        _, mix_label = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)(image, float_label)
        expected_label = np.array([[0., 0.9285714, 0., 0., 0.0714286],
                                   [0., 0., 0., 0., 1.],
                                   [0., 0.02040815, 0., 0., 0.97959185]])
        np.testing.assert_almost_equal(mix_label, expected_label)
    finally:
        # Restore the global seed even when an assertion above fails, so a
        # failure here cannot change the randomness seen by later tests.
        ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_cutmix_batch_twice():
    """
    Feature: CutMixBatch
    Description: Test CutMixBatch called twice
    Expectation: Output is as expected
    """
    # The expected labels are hard-coded, so the seed is pinned first.
    original_seed = config_get_set_seed(5)
    try:
        dataset = ds.Cifar10Dataset(DATA_DIR, num_samples=3, shuffle=False)
        one_hot = transforms.OneHot(num_classes=10)
        dataset = dataset.map(operations=one_hot, input_columns=["label"])
        cut_mix_batch = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 2.0, 0.5)
        dataset = dataset.batch(3, drop_remainder=False)

        # The same op instance is mapped twice: the second pass consumes the
        # labels produced by the first one.
        dataset = dataset.map(operations=cut_mix_batch, input_columns=["image", "label"])
        dataset = dataset.map(operations=cut_mix_batch, input_columns=["image", "label"])

        expected_label = np.array([[0.58618164, 0.41107178, 0.00274658, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.00109863, 0.9766998, 0.02220154, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.15673828, 0.02197266, 0.82128906, 0., 0., 0., 0., 0., 0., 0.]])
        for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            np.testing.assert_almost_equal(item["label"], expected_label)
    finally:
        # Restore the global seed even when an assertion above fails, so a
        # failure here cannot change the randomness seen by later tests.
        ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail1():
|
||||
"""
|
||||
Feature: CutMixBatch op
|
||||
|
@ -295,7 +341,7 @@ def test_cutmix_batch_fail1():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
with pytest.raises(RuntimeError) as error:
|
||||
|
@ -320,7 +366,7 @@ def test_cutmix_batch_fail2():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
|
||||
|
@ -339,7 +385,7 @@ def test_cutmix_batch_fail3():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
|
||||
|
@ -358,7 +404,7 @@ def test_cutmix_batch_fail4():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
|
||||
|
@ -377,7 +423,7 @@ def test_cutmix_batch_fail5():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
|
@ -405,7 +451,7 @@ def test_cutmix_batch_fail6():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
|
@ -459,7 +505,7 @@ def test_cutmix_batch_fail8():
|
|||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
|
||||
|
@ -467,6 +513,20 @@ def test_cutmix_batch_fail8():
|
|||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cut_mix_batch_invalid_label_type():
    """
    Feature: CutMixBatch
    Description: Test CutMixBatch with label in str type
    Expectation: Error is raised as expected
    """
    batch = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
    # String labels are not a numeric type, so the op is expected to reject them.
    str_label = np.array([["one"], ["two"], ["three"]])

    with pytest.raises(RuntimeError) as exc_info:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC)(batch, str_label)

    # Checked after the `with` block so that the assertion actually executes.
    expected_fragment = "CutMixBatch: invalid label type, label must be in a numeric type"
    assert expected_fragment in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cutmix_batch_success1(plot=True)
|
||||
test_cutmix_batch_success2(plot=True)
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
import mindspore.dataset.transforms as data_trans
|
||||
import mindspore.dataset.transforms as transforms
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
||||
config_get_set_num_parallel_workers
|
||||
|
@ -53,7 +53,7 @@ def test_mixup_batch_success1(plot=False):
|
|||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
mixup_batch_op = vision.MixUpBatch(2)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
|
@ -102,7 +102,7 @@ def test_mixup_batch_success2(plot=False):
|
|||
decode_op = vision.Decode()
|
||||
data1 = data1.map(operations=[decode_op], input_columns=["image"])
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
|
||||
mixup_batch_op = vision.MixUpBatch(2.0)
|
||||
|
@ -147,7 +147,7 @@ def test_mixup_batch_success3(plot=False):
|
|||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
mixup_batch_op = vision.MixUpBatch()
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
|
@ -196,7 +196,7 @@ def test_mixup_batch_success4(plot=False):
|
|||
decode_op = vision.Decode()
|
||||
data1 = data1.map(operations=[decode_op], input_columns=["image"])
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=100)
|
||||
one_hot_op = transforms.OneHot(num_classes=100)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
|
||||
|
||||
mixup_batch_op = vision.MixUpBatch()
|
||||
|
@ -232,7 +232,7 @@ def test_mixup_batch_md5():
|
|||
# MixUp Images
|
||||
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data = data.map(operations=one_hot_op, input_columns=["label"])
|
||||
mixup_batch_op = vision.MixUpBatch()
|
||||
data = data.batch(5, drop_remainder=True)
|
||||
|
@ -246,6 +246,52 @@ def test_mixup_batch_md5():
|
|||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_mixup_batch_float_label():
    """
    Feature: MixUpBatch
    Description: Test MixUpBatch with label in type of float
    Expectation: Output is as expected
    """
    # The expected values below are hard-coded, so the seed is pinned first.
    # NOTE(review): presumably config_get_set_seed also governs the NumPy draws - confirm.
    original_seed = config_get_set_seed(0)
    try:
        image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
        label = np.random.randint(0, 5, (3, 1))

        # One-hot encode the integer labels, then cast to float so that the op
        # receives a float label column.
        decode_label = transforms.OneHot(5)(label)
        float_label = transforms.TypeCast(float)(decode_label)

        _, mix_label = vision.MixUpBatch()(image, float_label)
        expected_label = np.array([[0., 0.6824126, 0., 0., 0.3175874],
                                   [0., 0., 0., 0., 1.],
                                   [0., 0.3175874, 0., 0., 0.6824126]])
        np.testing.assert_almost_equal(mix_label, expected_label)
    finally:
        # Restore the global seed even when an assertion above fails, so a
        # failure here cannot change the randomness seen by later tests.
        ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_mixup_batch_twice():
    """
    Feature: MixUpBatch
    Description: Test MixUpBatch called twice
    Expectation: Output is as expected
    """
    # The expected labels are hard-coded, so the seed is pinned first.
    original_seed = config_get_set_seed(1)
    try:
        dataset = ds.Cifar10Dataset(DATA_DIR, num_samples=3, shuffle=False)
        one_hot = transforms.OneHot(num_classes=10)
        dataset = dataset.map(operations=one_hot, input_columns=["label"])
        mix_up_batch = vision.MixUpBatch()
        dataset = dataset.batch(3, drop_remainder=False)

        # The same op instance is mapped twice: the second pass consumes the
        # labels produced by the first one.
        dataset = dataset.map(operations=mix_up_batch, input_columns=["image", "label"])
        dataset = dataset.map(operations=mix_up_batch, input_columns=["image", "label"])

        expected_label = np.array([[0.29373336, 0.49647674, 0.20978989, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.20978989, 0.29373336, 0.49647674, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.49647674, 0.20978989, 0.29373336, 0., 0., 0., 0., 0., 0., 0.]])
        for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            np.testing.assert_almost_equal(item["label"], expected_label)
    finally:
        # Restore the global seed even when an assertion above fails, so a
        # failure here cannot change the randomness seen by later tests.
        ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_mixup_batch_fail1():
|
||||
"""
|
||||
Feature: MixUpBatch op
|
||||
|
@ -268,7 +314,7 @@ def test_mixup_batch_fail1():
|
|||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
mixup_batch_op = vision.MixUpBatch(0.1)
|
||||
with pytest.raises(RuntimeError) as error:
|
||||
|
@ -304,7 +350,7 @@ def test_mixup_batch_fail2():
|
|||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.MixUpBatch(-1)
|
||||
|
@ -333,7 +379,7 @@ def test_mixup_batch_fail3():
|
|||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
mixup_batch_op = vision.MixUpBatch()
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
|
@ -372,7 +418,7 @@ def test_mixup_batch_fail4():
|
|||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
one_hot_op = transforms.OneHot(num_classes=10)
|
||||
data1 = data1.map(operations=one_hot_op, input_columns=["label"])
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.MixUpBatch(0.0)
|
||||
|
@ -417,6 +463,20 @@ def test_mixup_batch_fail5():
|
|||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_mix_up_batch_invalid_label_type():
    """
    Feature: MixUpBatch
    Description: Test MixUpBatch with label in str type
    Expectation: Error is raised as expected
    """
    batch = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
    # String labels are not a numeric type, so the op is expected to reject them.
    str_label = np.array([["one"], ["two"], ["three"]])

    with pytest.raises(RuntimeError) as exc_info:
        vision.MixUpBatch()(batch, str_label)

    # Checked after the `with` block so that the assertion actually executes.
    expected_fragment = "MixUpBatch: invalid label type, label must be in a numeric type"
    assert expected_fragment in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mixup_batch_success1(plot=True)
|
||||
test_mixup_batch_success2(plot=True)
|
||||
|
|
Loading…
Reference in New Issue