!45156 [MD] Support float label in MixUpBatch and CutMixBatch

Merge pull request !45156 from xiaotianci/mix_float_label
This commit is contained in:
i-robot 2022-11-08 02:04:46 +00:00 committed by Gitee
commit 72389ba6d3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 189 additions and 82 deletions

View File

@ -75,11 +75,10 @@ Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"CutMixBatch: please make sure images are <H,W,C> or <C,H,W> format, and batched before calling CutMixBatch."); "CutMixBatch: please make sure images are <H,W,C> or <C,H,W> format, and batched before calling CutMixBatch.");
} }
if (!input.at(1)->type().IsInt()) {
RETURN_STATUS_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(input.at(1)->type().IsNumeric(),
"CutMixBatch: Wrong labels type. The second column (labels) must only include int types, but got:" + "CutMixBatch: invalid label type, label must be in a numeric type, but got: " +
input.at(1)->type().ToString()); input.at(1)->type().ToString() + ". You may need to perform OneHot first.");
}
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) { if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"CutMixBatch: wrong labels shape. " "CutMixBatch: wrong labels shape. "
@ -105,19 +104,19 @@ Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
return Status::OK(); return Status::OK();
} }
Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam, Status CutMixBatchOp::ComputeImage(const std::shared_ptr<Tensor> &image, int64_t rand_indx_i, float lam,
std::shared_ptr<Tensor> *image_i) { float *label_lam, std::shared_ptr<Tensor> *image_i) {
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector(); std::vector<int64_t> image_shape = image->shape().AsVector();
int x, y, crop_width, crop_height; int x, y, crop_width, crop_height;
// Get a random image // Get a random image
TensorShape remaining({-1}); TensorShape remaining({-1});
uchar *start_addr_of_index = nullptr; uchar *start_addr_of_index = nullptr;
std::shared_ptr<Tensor> rand_image; std::shared_ptr<Tensor> rand_image;
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining)); RETURN_IF_NOT_OK(image->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
RETURN_IF_NOT_OK(Tensor::CreateFromMemory( RETURN_IF_NOT_OK(Tensor::CreateFromMemory(
TensorShape({image_shape[kDimensionOne], image_shape[kDimensionTwo], image_shape[kDimensionThree]}), TensorShape({image_shape[kDimensionOne], image_shape[kDimensionTwo], image_shape[kDimensionThree]}), image->type(),
input.at(0)->type(), start_addr_of_index, &rand_image)); start_addr_of_index, &rand_image));
// Compute image // Compute image
if (image_batch_format_ == ImageBatchFormat::kNHWC) { if (image_batch_format_ == ImageBatchFormat::kNHWC) {
@ -154,30 +153,22 @@ Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_in
return Status::OK(); return Status::OK();
} }
Status CutMixBatchOp::ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i, Status CutMixBatchOp::ComputeLabel(const std::shared_ptr<Tensor> &label, int64_t rand_indx_i, int64_t index_i,
const int64_t row_labels, const int64_t num_classes, int64_t row_labels, int64_t num_classes, std::size_t label_shape_size,
const std::size_t label_shape_size, const float label_lam, float label_lam, std::shared_ptr<Tensor> *out_labels) {
std::shared_ptr<Tensor> *out_labels) {
// Compute labels // Compute labels
std::shared_ptr<Tensor> float_label;
RETURN_IF_NOT_OK(TypeCast(label, &float_label, DataType(DataType::DE_FLOAT32)));
for (int64_t j = 0; j < row_labels; j++) { for (int64_t j = 0; j < row_labels; j++) {
for (int64_t k = 0; k < num_classes; k++) { for (int64_t k = 0; k < num_classes; k++) {
std::vector<int64_t> first_index = std::vector<int64_t> first_index =
label_shape_size == kMaxLabelShapeSize ? std::vector{index_i, j, k} : std::vector{index_i, k}; label_shape_size == kMaxLabelShapeSize ? std::vector{index_i, j, k} : std::vector{index_i, k};
std::vector<int64_t> second_index = std::vector<int64_t> second_index =
label_shape_size == kMaxLabelShapeSize ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k}; label_shape_size == kMaxLabelShapeSize ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k};
if (input.at(1)->type().IsSignedInt()) { float first_value, second_value;
int64_t first_value, second_value; RETURN_IF_NOT_OK(float_label->GetItemAt(&first_value, first_index));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index)); RETURN_IF_NOT_OK(float_label->GetItemAt(&second_value, second_index));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index)); RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
RETURN_IF_NOT_OK(
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
} else {
uint64_t first_value, second_value;
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
RETURN_IF_NOT_OK(
(*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
}
} }
} }
@ -208,7 +199,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Tensor holding the output labels // Tensor holding the output labels
std::shared_ptr<Tensor> out_labels; std::shared_ptr<Tensor> out_labels;
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType(DataType::DE_FLOAT32))); RETURN_IF_NOT_OK(TypeCast(input.at(1), &out_labels, DataType(DataType::DE_FLOAT32)));
int64_t row_labels = label_shape.size() == kValueThree ? label_shape[kDimensionOne] : kValueOne; int64_t row_labels = label_shape.size() == kValueThree ? label_shape[kDimensionOne] : kValueOne;
int64_t num_classes = label_shape.size() == kValueThree ? label_shape[kDimensionTwo] : label_shape[kDimensionOne]; int64_t num_classes = label_shape.size() == kValueThree ? label_shape[kDimensionTwo] : label_shape[kDimensionOne];
@ -227,9 +218,9 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
if (random_number < prob_) { if (random_number < prob_) {
float label_lam; // lambda used for labels float label_lam; // lambda used for labels
// Compute image // Compute image
RETURN_IF_NOT_OK(ComputeImage(input, rand_indx[i], lam, &label_lam, &images[i])); RETURN_IF_NOT_OK(ComputeImage(input.at(0), rand_indx[i], lam, &label_lam, &images[i]));
// Compute labels // Compute labels
RETURN_IF_NOT_OK(ComputeLabel(input, rand_indx[i], static_cast<int64_t>(i), row_labels, num_classes, RETURN_IF_NOT_OK(ComputeLabel(input.at(1), rand_indx[i], static_cast<int64_t>(i), row_labels, num_classes,
label_shape.size(), label_lam, &out_labels)); label_shape.size(), label_lam, &out_labels));
} }
} }

View File

@ -54,7 +54,7 @@ class CutMixBatchOp : public TensorOp {
/// \param[in] label_lam Lambda used for labels, will be updated after computing each image. /// \param[in] label_lam Lambda used for labels, will be updated after computing each image.
/// \param[in] image_i The result of the i-th computed image. /// \param[in] image_i The result of the i-th computed image.
/// \returns Status /// \returns Status
Status ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam, Status ComputeImage(const std::shared_ptr<Tensor> &image, int64_t rand_indx_i, float lam, float *label_lam,
std::shared_ptr<Tensor> *image_i); std::shared_ptr<Tensor> *image_i);
/// \brief Helper function used in Compute to compute each label corresponding to each image. /// \brief Helper function used in Compute to compute each label corresponding to each image.
@ -67,9 +67,10 @@ class CutMixBatchOp : public TensorOp {
/// \param[in] label_lam Lambda used for setting the location. /// \param[in] label_lam Lambda used for setting the location.
/// \param[in] out_labels The output of the i-th label, corresponding to the i-th computed image. /// \param[in] out_labels The output of the i-th label, corresponding to the i-th computed image.
/// \returns Status /// \returns Status
Status ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i, Status ComputeLabel(const std::shared_ptr<Tensor> &label, int64_t rand_indx_i, int64_t index_i, int64_t row_labels,
const int64_t row_labels, const int64_t num_classes, const std::size_t label_shape_size, int64_t num_classes, std::size_t label_shape_size, float label_lam,
const float label_lam, std::shared_ptr<Tensor> *out_labels); std::shared_ptr<Tensor> *out_labels);
float alpha_; float alpha_;
float prob_; float prob_;
ImageBatchFormat image_batch_format_; ImageBatchFormat image_batch_format_;

View File

@ -39,9 +39,9 @@ constexpr int64_t value_three = 3;
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); } MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); }
Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tensor> *out_labels, Status MixUpBatchOp::ComputeLabels(const std::shared_ptr<Tensor> &label, std::shared_ptr<Tensor> *out_labels,
std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, float lam,
const float lam, const size_t images_size) { size_t images_size) {
CHECK_FAIL_RETURN_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(
images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()), images_size <= static_cast<size_t>(std::numeric_limits<int64_t>::max()),
"The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size)); "The \'images_size\' must not be more than \'INT64_MAX\', but got: " + std::to_string(images_size));
@ -50,7 +50,9 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
} }
std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_); std::shuffle(rand_indx->begin(), rand_indx->end(), rnd_);
RETURN_IF_NOT_OK(TypeCast(input.at(1), out_labels, DataType(DataType::DE_FLOAT32))); std::shared_ptr<Tensor> float_label;
RETURN_IF_NOT_OK(TypeCast(label, &float_label, DataType(DataType::DE_FLOAT32)));
RETURN_IF_NOT_OK(TypeCast(label, out_labels, DataType(DataType::DE_FLOAT32)));
int64_t row_labels = label_shape.size() == kMaxLabelShapeSize ? label_shape[1] : 1; int64_t row_labels = label_shape.size() == kMaxLabelShapeSize ? label_shape[1] : 1;
int64_t num_classes = label_shape.size() == kMaxLabelShapeSize ? label_shape[dimension_two] : label_shape[1]; int64_t num_classes = label_shape.size() == kMaxLabelShapeSize ? label_shape[dimension_two] : label_shape[1];
@ -63,17 +65,10 @@ Status MixUpBatchOp::ComputeLabels(const TensorRow &input, std::shared_ptr<Tenso
std::vector<int64_t> second_index = label_shape.size() == kMaxLabelShapeSize std::vector<int64_t> second_index = label_shape.size() == kMaxLabelShapeSize
? std::vector{(*rand_indx)[static_cast<size_t>(i)], j, k} ? std::vector{(*rand_indx)[static_cast<size_t>(i)], j, k}
: std::vector{(*rand_indx)[static_cast<size_t>(i)], k}; : std::vector{(*rand_indx)[static_cast<size_t>(i)], k};
if (input.at(1)->type().IsSignedInt()) { float first_value, second_value;
int64_t first_value, second_value; RETURN_IF_NOT_OK(float_label->GetItemAt(&first_value, first_index));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index)); RETURN_IF_NOT_OK(float_label->GetItemAt(&second_value, second_index));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index)); RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
} else {
uint64_t first_value, second_value;
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
RETURN_IF_NOT_OK((*out_labels)->SetItemAt(first_index, lam * first_value + (1 - lam) * second_value));
}
} }
} }
} }
@ -96,12 +91,10 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
", but got: " + std::to_string(image_shape.size()) + ", but got: " + std::to_string(image_shape.size()) +
", make sure image shape are <H,W,C> or <C,H,W> and batched before calling MixUpBatch."); ", make sure image shape are <H,W,C> or <C,H,W> and batched before calling MixUpBatch.");
} }
if (!input.at(1)->type().IsInt()) {
RETURN_STATUS_UNEXPECTED(
"MixUpBatch: wrong labels type. The second column (labels) must only include int types, but got: " +
input.at(1)->type().ToString());
}
CHECK_FAIL_RETURN_UNEXPECTED(input.at(1)->type().IsNumeric(),
"MixUpBatch: invalid label type, label must be in a numeric type, but got: " +
input.at(1)->type().ToString() + ". You may need to perform OneHot first.");
if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) { if (label_shape.size() != kMinLabelShapeSize && label_shape.size() != kMaxLabelShapeSize) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"MixUpBatch: wrong labels shape. " "MixUpBatch: wrong labels shape. "
@ -137,7 +130,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
std::shared_ptr<Tensor> out_labels; std::shared_ptr<Tensor> out_labels;
// Compute labels // Compute labels
RETURN_IF_NOT_OK(ComputeLabels(input, &out_labels, &rand_indx, label_shape, lam, images.size())); RETURN_IF_NOT_OK(ComputeLabels(input.at(1), &out_labels, &rand_indx, label_shape, lam, images.size()));
// Compute images // Compute images
for (int64_t i = 0; i < images.size(); i++) { for (int64_t i = 0; i < images.size(); i++) {

View File

@ -43,8 +43,10 @@ class MixUpBatchOp : public TensorOp {
private: private:
// a helper function to shorten the main Compute function // a helper function to shorten the main Compute function
Status ComputeLabels(const TensorRow &input, std::shared_ptr<Tensor> *out_labels, std::vector<int64_t> *rand_indx, Status ComputeLabels(const std::shared_ptr<Tensor> &label, std::shared_ptr<Tensor> *out_labels,
const std::vector<int64_t> &label_shape, const float lam, const size_t images_size); std::vector<int64_t> *rand_indx, const std::vector<int64_t> &label_shape, float lam,
size_t images_size);
float alpha_; float alpha_;
std::mt19937 rnd_; std::mt19937 rnd_;
}; };

View File

@ -19,7 +19,7 @@ import numpy as np
import pytest import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.vision as vision import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as data_trans import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision.utils as mode import mindspore.dataset.vision.utils as mode
from mindspore import log as logger from mindspore import log as logger
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
@ -54,7 +54,7 @@ def test_cutmix_batch_success1(plot=False):
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
hwc2chw_op = vision.HWC2CHW() hwc2chw_op = vision.HWC2CHW()
data1 = data1.map(operations=hwc2chw_op, input_columns=["image"]) data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
@ -97,7 +97,7 @@ def test_cutmix_batch_success2(plot=False):
# CutMix Images # CutMix Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
rescale_op = vision.Rescale((1.0 / 255.0), 0.0) rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
data1 = data1.map(operations=rescale_op, input_columns=["image"]) data1 = data1.map(operations=rescale_op, input_columns=["image"])
@ -152,7 +152,7 @@ def test_cutmix_batch_success3(plot=False):
resize_op = vision.Resize([224, 224]) resize_op = vision.Resize([224, 224])
data1 = data1.map(operations=[resize_op], input_columns=["image"]) data1 = data1.map(operations=[resize_op], input_columns=["image"])
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
@ -206,7 +206,7 @@ def test_cutmix_batch_success4(plot=False):
resize_op = vision.Resize([224, 224]) resize_op = vision.Resize([224, 224])
data1 = data1.map(operations=[resize_op], input_columns=["image"]) data1 = data1.map(operations=[resize_op], input_columns=["image"])
one_hot_op = data_trans.OneHot(num_classes=100) one_hot_op = transforms.OneHot(num_classes=100)
data1 = data1.map(operations=one_hot_op, input_columns=["attr"]) data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
@ -242,7 +242,7 @@ def test_cutmix_batch_nhwc_md5():
# CutMixBatch Images # CutMixBatch Images
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data = data.map(operations=one_hot_op, input_columns=["label"]) data = data.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
data = data.batch(5, drop_remainder=True) data = data.batch(5, drop_remainder=True)
@ -270,7 +270,7 @@ def test_cutmix_batch_nchw_md5():
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
hwc2chw_op = vision.HWC2CHW() hwc2chw_op = vision.HWC2CHW()
data = data.map(operations=hwc2chw_op, input_columns=["image"]) data = data.map(operations=hwc2chw_op, input_columns=["image"])
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data = data.map(operations=one_hot_op, input_columns=["label"]) data = data.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
data = data.batch(5, drop_remainder=True) data = data.batch(5, drop_remainder=True)
@ -284,6 +284,52 @@ def test_cutmix_batch_nchw_md5():
ds.config.set_num_parallel_workers(original_num_parallel_workers) ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_cutmix_batch_float_label():
    """
    Feature: CutMixBatch
    Description: Test CutMixBatch in eager mode with a one-hot label that has been cast to float
    Expectation: The mixed label matches the values recorded for seed 0
    """
    original_seed = config_get_set_seed(0)
    try:
        image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
        label = np.random.randint(0, 5, (3, 1))
        # NOTE(review): the expected values below presumably rely on the seed set above also
        # driving numpy's generator - confirm against config_get_set_seed.
        one_hot_label = transforms.OneHot(5)(label)
        float_label = transforms.TypeCast(float)(one_hot_label)

        _, mix_label = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)(image, float_label)

        expected_label = np.array([[0., 0.9285714, 0., 0., 0.0714286],
                                   [0., 0., 0., 0., 1.],
                                   [0., 0.02040815, 0., 0., 0.97959185]])
        np.testing.assert_almost_equal(mix_label, expected_label)
    finally:
        # Restore the global seed even when the assertion fails, so that a failure here
        # does not change the random state seen by the tests that run afterwards.
        ds.config.set_seed(original_seed)
def test_cutmix_batch_twice():
    """
    Feature: CutMixBatch
    Description: Test applying the same CutMixBatch operation twice in a pipeline, so that the
                 second call receives the float label produced by the first call
    Expectation: The mixed label matches the values recorded for seed 5
    """
    original_seed = config_get_set_seed(5)
    try:
        dataset = ds.Cifar10Dataset(DATA_DIR, num_samples=3, shuffle=False)
        one_hot = transforms.OneHot(num_classes=10)
        dataset = dataset.map(operations=one_hot, input_columns=["label"])
        cut_mix_batch = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 2.0, 0.5)
        dataset = dataset.batch(3, drop_remainder=False)
        dataset = dataset.map(operations=cut_mix_batch, input_columns=["image", "label"])
        dataset = dataset.map(operations=cut_mix_batch, input_columns=["image", "label"])

        expected_label = np.array([[0.58618164, 0.41107178, 0.00274658, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.00109863, 0.9766998, 0.02220154, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.15673828, 0.02197266, 0.82128906, 0., 0., 0., 0., 0., 0., 0.]])

        num_batches = 0
        for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            np.testing.assert_almost_equal(item["label"], expected_label)
            num_batches += 1
        # 3 samples batched by 3 give exactly one batch; without this check the test would
        # pass silently if the pipeline produced no data at all.
        assert num_batches == 1
    finally:
        # Restore the global seed even when an assertion fails, so later tests are unaffected.
        ds.config.set_seed(original_seed)
def test_cutmix_batch_fail1(): def test_cutmix_batch_fail1():
""" """
Feature: CutMixBatch op Feature: CutMixBatch op
@ -295,7 +341,7 @@ def test_cutmix_batch_fail1():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
with pytest.raises(RuntimeError) as error: with pytest.raises(RuntimeError) as error:
@ -320,7 +366,7 @@ def test_cutmix_batch_fail2():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
with pytest.raises(ValueError) as error: with pytest.raises(ValueError) as error:
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1) vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
@ -339,7 +385,7 @@ def test_cutmix_batch_fail3():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
with pytest.raises(ValueError) as error: with pytest.raises(ValueError) as error:
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2) vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
@ -358,7 +404,7 @@ def test_cutmix_batch_fail4():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
with pytest.raises(ValueError) as error: with pytest.raises(ValueError) as error:
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1) vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
@ -377,7 +423,7 @@ def test_cutmix_batch_fail5():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
@ -405,7 +451,7 @@ def test_cutmix_batch_fail6():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
@ -459,7 +505,7 @@ def test_cutmix_batch_fail8():
# CutMixBatch Images # CutMixBatch Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
with pytest.raises(ValueError) as error: with pytest.raises(ValueError) as error:
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0) vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
@ -467,6 +513,20 @@ def test_cutmix_batch_fail8():
assert error_message in str(error.value) assert error_message in str(error.value)
def test_cut_mix_batch_invalid_label_type():
    """
    Feature: CutMixBatch
    Description: Test CutMixBatch in eager mode when the label column holds strings
    Expectation: RuntimeError is raised and its message reports the invalid label type
    """
    expected_message = "CutMixBatch: invalid label type, label must be in a numeric type"
    str_label = np.array([["one"], ["two"], ["three"]])
    batched_image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
    cut_mix_batch = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)

    # The expected message has no regex metacharacters, so `match` acts as a substring check.
    with pytest.raises(RuntimeError, match=expected_message):
        cut_mix_batch(batched_image, str_label)
if __name__ == "__main__": if __name__ == "__main__":
test_cutmix_batch_success1(plot=True) test_cutmix_batch_success1(plot=True)
test_cutmix_batch_success2(plot=True) test_cutmix_batch_success2(plot=True)

View File

@ -19,7 +19,7 @@ import numpy as np
import pytest import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.vision as vision import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as data_trans import mindspore.dataset.transforms as transforms
from mindspore import log as logger from mindspore import log as logger
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
config_get_set_num_parallel_workers config_get_set_num_parallel_workers
@ -53,7 +53,7 @@ def test_mixup_batch_success1(plot=False):
# MixUp Images # MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
mixup_batch_op = vision.MixUpBatch(2) mixup_batch_op = vision.MixUpBatch(2)
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
@ -102,7 +102,7 @@ def test_mixup_batch_success2(plot=False):
decode_op = vision.Decode() decode_op = vision.Decode()
data1 = data1.map(operations=[decode_op], input_columns=["image"]) data1 = data1.map(operations=[decode_op], input_columns=["image"])
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
mixup_batch_op = vision.MixUpBatch(2.0) mixup_batch_op = vision.MixUpBatch(2.0)
@ -147,7 +147,7 @@ def test_mixup_batch_success3(plot=False):
# MixUp Images # MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
mixup_batch_op = vision.MixUpBatch() mixup_batch_op = vision.MixUpBatch()
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
@ -196,7 +196,7 @@ def test_mixup_batch_success4(plot=False):
decode_op = vision.Decode() decode_op = vision.Decode()
data1 = data1.map(operations=[decode_op], input_columns=["image"]) data1 = data1.map(operations=[decode_op], input_columns=["image"])
one_hot_op = data_trans.OneHot(num_classes=100) one_hot_op = transforms.OneHot(num_classes=100)
data1 = data1.map(operations=one_hot_op, input_columns=["attr"]) data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
mixup_batch_op = vision.MixUpBatch() mixup_batch_op = vision.MixUpBatch()
@ -232,7 +232,7 @@ def test_mixup_batch_md5():
# MixUp Images # MixUp Images
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data = data.map(operations=one_hot_op, input_columns=["label"]) data = data.map(operations=one_hot_op, input_columns=["label"])
mixup_batch_op = vision.MixUpBatch() mixup_batch_op = vision.MixUpBatch()
data = data.batch(5, drop_remainder=True) data = data.batch(5, drop_remainder=True)
@ -246,6 +246,52 @@ def test_mixup_batch_md5():
ds.config.set_num_parallel_workers(original_num_parallel_workers) ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_mixup_batch_float_label():
    """
    Feature: MixUpBatch
    Description: Test MixUpBatch in eager mode with a one-hot label that has been cast to float
    Expectation: The mixed label matches the values recorded for seed 0
    """
    original_seed = config_get_set_seed(0)
    try:
        image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
        label = np.random.randint(0, 5, (3, 1))
        # NOTE(review): the expected values below presumably rely on the seed set above also
        # driving numpy's generator - confirm against config_get_set_seed.
        one_hot_label = transforms.OneHot(5)(label)
        float_label = transforms.TypeCast(float)(one_hot_label)

        _, mix_label = vision.MixUpBatch()(image, float_label)

        expected_label = np.array([[0., 0.6824126, 0., 0., 0.3175874],
                                   [0., 0., 0., 0., 1.],
                                   [0., 0.3175874, 0., 0., 0.6824126]])
        np.testing.assert_almost_equal(mix_label, expected_label)
    finally:
        # Restore the global seed even when the assertion fails, so that a failure here
        # does not change the random state seen by the tests that run afterwards.
        ds.config.set_seed(original_seed)
def test_mixup_batch_twice():
    """
    Feature: MixUpBatch
    Description: Test applying the same MixUpBatch operation twice in a pipeline, so that the
                 second call receives the float label produced by the first call
    Expectation: The mixed label matches the values recorded for seed 1
    """
    original_seed = config_get_set_seed(1)
    try:
        dataset = ds.Cifar10Dataset(DATA_DIR, num_samples=3, shuffle=False)
        one_hot = transforms.OneHot(num_classes=10)
        dataset = dataset.map(operations=one_hot, input_columns=["label"])
        mix_up_batch = vision.MixUpBatch()
        dataset = dataset.batch(3, drop_remainder=False)
        dataset = dataset.map(operations=mix_up_batch, input_columns=["image", "label"])
        dataset = dataset.map(operations=mix_up_batch, input_columns=["image", "label"])

        expected_label = np.array([[0.29373336, 0.49647674, 0.20978989, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.20978989, 0.29373336, 0.49647674, 0., 0., 0., 0., 0., 0., 0.],
                                   [0.49647674, 0.20978989, 0.29373336, 0., 0., 0., 0., 0., 0., 0.]])

        num_batches = 0
        for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            np.testing.assert_almost_equal(item["label"], expected_label)
            num_batches += 1
        # 3 samples batched by 3 give exactly one batch; without this check the test would
        # pass silently if the pipeline produced no data at all.
        assert num_batches == 1
    finally:
        # Restore the global seed even when an assertion fails, so later tests are unaffected.
        ds.config.set_seed(original_seed)
def test_mixup_batch_fail1(): def test_mixup_batch_fail1():
""" """
Feature: MixUpBatch op Feature: MixUpBatch op
@ -268,7 +314,7 @@ def test_mixup_batch_fail1():
# MixUp Images # MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
mixup_batch_op = vision.MixUpBatch(0.1) mixup_batch_op = vision.MixUpBatch(0.1)
with pytest.raises(RuntimeError) as error: with pytest.raises(RuntimeError) as error:
@ -304,7 +350,7 @@ def test_mixup_batch_fail2():
# MixUp Images # MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
with pytest.raises(ValueError) as error: with pytest.raises(ValueError) as error:
vision.MixUpBatch(-1) vision.MixUpBatch(-1)
@ -333,7 +379,7 @@ def test_mixup_batch_fail3():
# MixUp Images # MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
mixup_batch_op = vision.MixUpBatch() mixup_batch_op = vision.MixUpBatch()
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
@ -372,7 +418,7 @@ def test_mixup_batch_fail4():
# MixUp Images # MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = transforms.OneHot(num_classes=10)
data1 = data1.map(operations=one_hot_op, input_columns=["label"]) data1 = data1.map(operations=one_hot_op, input_columns=["label"])
with pytest.raises(ValueError) as error: with pytest.raises(ValueError) as error:
vision.MixUpBatch(0.0) vision.MixUpBatch(0.0)
@ -417,6 +463,20 @@ def test_mixup_batch_fail5():
assert error_message in str(error.value) assert error_message in str(error.value)
def test_mix_up_batch_invalid_label_type():
    """
    Feature: MixUpBatch
    Description: Test MixUpBatch in eager mode when the label column holds strings
    Expectation: RuntimeError is raised and its message reports the invalid label type
    """
    expected_message = "MixUpBatch: invalid label type, label must be in a numeric type"
    str_label = np.array([["one"], ["two"], ["three"]])
    batched_image = np.random.randint(0, 255, (3, 28, 28, 1), dtype=np.uint8)
    mix_up_batch = vision.MixUpBatch()

    # The expected message has no regex metacharacters, so `match` acts as a substring check.
    with pytest.raises(RuntimeError, match=expected_message):
        mix_up_batch(batched_image, str_label)
if __name__ == "__main__": if __name__ == "__main__":
test_mixup_batch_success1(plot=True) test_mixup_batch_success1(plot=True)
test_mixup_batch_success2(plot=True) test_mixup_batch_success2(plot=True)