forked from mindspore-Ecosystem/mindspore
!4955 Fixes for Dynamic Augmentation Ops
Merge pull request !4955 from MahdiRahmaniHanzaki/dynamic-ops-fix
This commit is contained in:
commit
9b503e4f38
|
@ -382,7 +382,7 @@ CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format,
|
|||
: image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
|
||||
|
||||
bool CutMixBatchOperation::ValidateParams() {
|
||||
if (alpha_ < 0) {
|
||||
if (alpha_ <= 0) {
|
||||
MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
|
||||
return false;
|
||||
}
|
||||
|
@ -434,7 +434,7 @@ std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<H
|
|||
MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
|
||||
|
||||
bool MixUpBatchOperation::ValidateParams() {
|
||||
if (alpha_ < 0) {
|
||||
if (alpha_ <= 0) {
|
||||
MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
|
||||
// Check inputs
|
||||
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
RETURN_STATUS_UNEXPECTED("You must batch before calling CutMixBatch.");
|
||||
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling CutMixBatch.");
|
||||
}
|
||||
if (label_shape.size() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch");
|
||||
|
@ -139,10 +139,17 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
|
||||
// Compute labels
|
||||
for (int j = 0; j < label_shape[1]; j++) {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
if (input.at(1)->type().IsSignedInt()) {
|
||||
int64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
} else {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
|
||||
// Check inputs
|
||||
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
|
||||
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch");
|
||||
}
|
||||
if (label_shape.size() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch");
|
||||
|
@ -68,10 +68,17 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType("float32")));
|
||||
for (int64_t i = 0; i < label_shape[0]; i++) {
|
||||
for (int64_t j = 0; j < label_shape[1]; j++) {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
|
||||
if (input.at(1)->type().IsSignedInt()) {
|
||||
int64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
|
||||
} else {
|
||||
uint64_t first_value, second_value;
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
|
||||
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
|
||||
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -231,7 +231,7 @@ class Normalize(cde.NormalizeOp):
|
|||
|
||||
class RandomAffine(cde.RandomAffineOp):
|
||||
"""
|
||||
Apply Random affine transformation to the input PIL image.
|
||||
Apply Random affine transformation to the input image.
|
||||
|
||||
Args:
|
||||
degrees (int or float or sequence): Range of the rotation degrees.
|
||||
|
@ -681,12 +681,12 @@ class CenterCrop(cde.CenterCropOp):
|
|||
class RandomColor(cde.RandomColorOp):
|
||||
"""
|
||||
Adjust the color of the input image by a fixed or random degree.
|
||||
This operation works only with 3-channel color images.
|
||||
|
||||
Args:
|
||||
degrees (sequence): Range of random color adjustment degrees.
|
||||
It should be in (min, max) format. If min=max, then it is a
|
||||
single fixed magnitude operation (default=(0.1,1.9)).
|
||||
Works with 3-channel color images.
|
||||
"""
|
||||
|
||||
@check_positive_degrees
|
||||
|
|
|
@ -1169,39 +1169,12 @@ class RandomAffine:
|
|||
Returns:
|
||||
img (PIL Image), Randomly affine transformed image.
|
||||
"""
|
||||
# rotation
|
||||
angle = random.uniform(self.degrees[0], self.degrees[1])
|
||||
|
||||
# translation
|
||||
if self.translate is not None:
|
||||
max_dx = self.translate[0] * img.size[0]
|
||||
max_dy = self.translate[1] * img.size[1]
|
||||
translations = (np.round(random.uniform(-max_dx, max_dx)),
|
||||
np.round(random.uniform(-max_dy, max_dy)))
|
||||
else:
|
||||
translations = (0, 0)
|
||||
|
||||
# scale
|
||||
if self.scale_ranges is not None:
|
||||
scale = random.uniform(self.scale_ranges[0], self.scale_ranges[1])
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
# shear
|
||||
if self.shear is not None:
|
||||
if len(self.shear) == 2:
|
||||
shear = [random.uniform(self.shear[0], self.shear[1]), 0.]
|
||||
elif len(self.shear) == 4:
|
||||
shear = [random.uniform(self.shear[0], self.shear[1]),
|
||||
random.uniform(self.shear[2], self.shear[3])]
|
||||
else:
|
||||
shear = 0.0
|
||||
|
||||
return util.random_affine(img,
|
||||
angle,
|
||||
translations,
|
||||
scale,
|
||||
shear,
|
||||
self.degrees,
|
||||
self.translate,
|
||||
self.scale_ranges,
|
||||
self.shear,
|
||||
self.resample,
|
||||
self.fill_value)
|
||||
|
||||
|
|
|
@ -1153,6 +1153,34 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
|
|||
if not is_pil(img):
|
||||
raise ValueError("Input image should be a Pillow image.")
|
||||
|
||||
# rotation
|
||||
angle = random.uniform(angle[0], angle[1])
|
||||
|
||||
# translation
|
||||
if translations is not None:
|
||||
max_dx = translations[0] * img.size[0]
|
||||
max_dy = translations[1] * img.size[1]
|
||||
translations = (np.round(random.uniform(-max_dx, max_dx)),
|
||||
np.round(random.uniform(-max_dy, max_dy)))
|
||||
else:
|
||||
translations = (0, 0)
|
||||
|
||||
# scale
|
||||
if scale is not None:
|
||||
scale = random.uniform(scale[0], scale[1])
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
# shear
|
||||
if shear is not None:
|
||||
if len(shear) == 2:
|
||||
shear = [random.uniform(shear[0], shear[1]), 0.]
|
||||
elif len(shear) == 4:
|
||||
shear = [random.uniform(shear[0], shear[1]),
|
||||
random.uniform(shear[2], shear[3])]
|
||||
else:
|
||||
shear = 0.0
|
||||
|
||||
output_size = img.size
|
||||
center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
|
||||
|
||||
|
@ -1416,7 +1444,6 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
|
|||
|
||||
|
||||
def random_color(img, degrees):
|
||||
|
||||
"""
|
||||
Adjust the color of the input PIL image by a random degree.
|
||||
|
||||
|
@ -1437,7 +1464,6 @@ def random_color(img, degrees):
|
|||
|
||||
|
||||
def random_sharpness(img, degrees):
|
||||
|
||||
"""
|
||||
Adjust the sharpness of the input PIL image by a random degree.
|
||||
|
||||
|
@ -1458,7 +1484,6 @@ def random_sharpness(img, degrees):
|
|||
|
||||
|
||||
def auto_contrast(img, cutoff, ignore):
|
||||
|
||||
"""
|
||||
Automatically maximize the contrast of the input PIL image.
|
||||
|
||||
|
@ -1479,7 +1504,6 @@ def auto_contrast(img, cutoff, ignore):
|
|||
|
||||
|
||||
def invert_color(img):
|
||||
|
||||
"""
|
||||
Invert colors of input PIL image.
|
||||
|
||||
|
@ -1498,7 +1522,6 @@ def invert_color(img):
|
|||
|
||||
|
||||
def equalize(img):
|
||||
|
||||
"""
|
||||
Equalize the histogram of input PIL image.
|
||||
|
||||
|
@ -1517,7 +1540,6 @@ def equalize(img):
|
|||
|
||||
|
||||
def uniform_augment(img, transforms, num_ops):
|
||||
|
||||
"""
|
||||
Uniformly select and apply a number of transforms sequentially from
|
||||
a list of transforms. Randomly assigns a probability to each transform for
|
||||
|
|
|
@ -45,6 +45,7 @@ def check_cut_mix_batch_c(method):
|
|||
[image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
|
||||
check_pos_float32(alpha)
|
||||
check_positive(alpha, "alpha")
|
||||
check_value(prob, [0, 1], "prob")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -68,6 +69,7 @@ def check_mix_up_batch_c(method):
|
|||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[alpha], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_positive(alpha, "alpha")
|
||||
check_pos_float32(alpha)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
|
|
@ -191,11 +191,37 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
|
|||
ds = ds->Map({one_hot_op},{"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
|
||||
1, -0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
|
||||
// Must fail because alpha can't be zero
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
|
||||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op},{"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
|
||||
0.0, 0.5);
|
||||
EXPECT_EQ(cutmix_batch_op, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCutOut) {
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
@ -365,6 +391,30 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
|
|||
EXPECT_EQ(mixup_batch_op, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
|
||||
// This should fail because alpha can't be zero
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 5;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
|
||||
EXPECT_NE(one_hot_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.0);
|
||||
EXPECT_EQ(mixup_batch_op, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
|
@ -384,7 +434,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
|
|||
ds = ds->Map({one_hot_op}, {"label"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
|
||||
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(2.0);
|
||||
EXPECT_NE(mixup_batch_op, nullptr);
|
||||
|
||||
// Create a Map operation on ds
|
||||
|
|
|
@ -26,6 +26,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
|
|||
config_get_set_num_parallel_workers
|
||||
|
||||
DATA_DIR = "../data/dataset/testCifar10Data"
|
||||
DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
@ -114,6 +115,53 @@ def test_cutmix_batch_success2(plot=False):
|
|||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_cutmix_batch_success3(plot=False):
|
||||
"""
|
||||
Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2
|
||||
"""
|
||||
logger.info("test_cutmix_batch_success3")
|
||||
|
||||
ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
|
||||
ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
|
||||
|
||||
images_original = None
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image
|
||||
else:
|
||||
images_original = np.append(images_original, image, axis=0)
|
||||
|
||||
# CutMix Images
|
||||
data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
|
||||
|
||||
decode_op = vision.Decode()
|
||||
data1 = data1.map(input_columns=["image"], operations=[decode_op])
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data1 = data1.batch(4, pad_info={}, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
||||
images_cutmix = None
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
if plot:
|
||||
visualize_list(images_original, images_cutmix)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_cutmix[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_cutmix_batch_nhwc_md5():
|
||||
"""
|
||||
Test CutMixBatch on a batch of HWC images with MD5:
|
||||
|
@ -185,7 +233,7 @@ def test_cutmix_batch_fail1():
|
|||
images_cutmix = image
|
||||
else:
|
||||
images_cutmix = np.append(images_cutmix, image, axis=0)
|
||||
error_message = "You must batch before calling CutMixBatch"
|
||||
error_message = "You must make sure images are HWC or CHW and batch "
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
|
@ -322,9 +370,28 @@ def test_cutmix_batch_fail7():
|
|||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_cutmix_batch_fail8():
|
||||
"""
|
||||
Test CutMixBatch Fail 8
|
||||
We expect this to fail because alpha is zero
|
||||
"""
|
||||
logger.info("test_cutmix_batch_fail8")
|
||||
|
||||
# CutMixBatch Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
|
||||
error_message = "Input is not within the required interval"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cutmix_batch_success1(plot=True)
|
||||
test_cutmix_batch_success2(plot=True)
|
||||
test_cutmix_batch_success3(plot=True)
|
||||
test_cutmix_batch_nchw_md5()
|
||||
test_cutmix_batch_nhwc_md5()
|
||||
test_cutmix_batch_fail1()
|
||||
|
@ -334,3 +401,4 @@ if __name__ == "__main__":
|
|||
test_cutmix_batch_fail5()
|
||||
test_cutmix_batch_fail6()
|
||||
test_cutmix_batch_fail7()
|
||||
test_cutmix_batch_fail8()
|
||||
|
|
|
@ -25,6 +25,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
|
|||
config_get_set_num_parallel_workers
|
||||
|
||||
DATA_DIR = "../data/dataset/testCifar10Data"
|
||||
DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
@ -71,11 +72,59 @@ def test_mixup_batch_success1(plot=False):
|
|||
|
||||
|
||||
def test_mixup_batch_success2(plot=False):
|
||||
"""
|
||||
Test MixUpBatch op with specified alpha parameter on ImageFolderDatasetV2
|
||||
"""
|
||||
logger.info("test_mixup_batch_success2")
|
||||
|
||||
# Original Images
|
||||
ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
|
||||
decode_op = vision.Decode()
|
||||
ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
|
||||
ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
|
||||
|
||||
images_original = None
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image
|
||||
else:
|
||||
images_original = np.append(images_original, image, axis=0)
|
||||
|
||||
# MixUp Images
|
||||
data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
|
||||
|
||||
decode_op = vision.Decode()
|
||||
data1 = data1.map(input_columns=["image"], operations=[decode_op])
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
|
||||
mixup_batch_op = vision.MixUpBatch(2.0)
|
||||
data1 = data1.batch(4, pad_info={}, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
|
||||
|
||||
images_mixup = None
|
||||
for idx, (image, _) in enumerate(data1):
|
||||
if idx == 0:
|
||||
images_mixup = image
|
||||
else:
|
||||
images_mixup = np.append(images_mixup, image, axis=0)
|
||||
if plot:
|
||||
visualize_list(images_original, images_mixup)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = diff_mse(images_mixup[i], images_original[i])
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
|
||||
def test_mixup_batch_success3(plot=False):
|
||||
"""
|
||||
Test MixUpBatch op without specified alpha parameter.
|
||||
Alpha parameter will be selected by default in this case
|
||||
"""
|
||||
logger.info("test_mixup_batch_success2")
|
||||
logger.info("test_mixup_batch_success3")
|
||||
|
||||
# Original Images
|
||||
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
@ -169,7 +218,7 @@ def test_mixup_batch_fail1():
|
|||
images_mixup = image
|
||||
else:
|
||||
images_mixup = np.append(images_mixup, image, axis=0)
|
||||
error_message = "You must batch before calling MixUp"
|
||||
error_message = "You must make sure images are HWC or CHW and batch"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
|
@ -207,6 +256,7 @@ def test_mixup_batch_fail3():
|
|||
Test MixUpBatch op
|
||||
We expect this to fail because label column is not passed to mixup_batch
|
||||
"""
|
||||
logger.info("test_mixup_batch_fail3")
|
||||
# Original Images
|
||||
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
ds_original = ds_original.batch(5, drop_remainder=True)
|
||||
|
@ -237,11 +287,41 @@ def test_mixup_batch_fail3():
|
|||
error_message = "Both images and labels columns are required"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
def test_mixup_batch_fail4():
|
||||
"""
|
||||
Test MixUpBatch Fail 2
|
||||
We expect this to fail because alpha is zero
|
||||
"""
|
||||
logger.info("test_mixup_batch_fail4")
|
||||
|
||||
# Original Images
|
||||
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
ds_original = ds_original.batch(5)
|
||||
|
||||
images_original = np.array([])
|
||||
for idx, (image, _) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = image
|
||||
else:
|
||||
images_original = np.append(images_original, image, axis=0)
|
||||
|
||||
# MixUp Images
|
||||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
with pytest.raises(ValueError) as error:
|
||||
vision.MixUpBatch(0.0)
|
||||
error_message = "Input is not within the required interval"
|
||||
assert error_message in str(error.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mixup_batch_success1(plot=True)
|
||||
test_mixup_batch_success2(plot=True)
|
||||
test_mixup_batch_success3(plot=True)
|
||||
test_mixup_batch_md5()
|
||||
test_mixup_batch_fail1()
|
||||
test_mixup_batch_fail2()
|
||||
test_mixup_batch_fail3()
|
||||
test_mixup_batch_fail4()
|
||||
|
|
|
@ -27,6 +27,7 @@ GENERATE_GOLDEN = False
|
|||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||
|
||||
|
||||
def test_random_affine_op(plot=False):
|
||||
|
@ -155,6 +156,24 @@ def test_random_affine_c_md5():
|
|||
ds.config.set_num_parallel_workers((original_num_parallel_workers))
|
||||
|
||||
|
||||
def test_random_affine_py_exception_non_pil_images():
|
||||
"""
|
||||
Test RandomAffine: input img is ndarray and not PIL, expected to raise TypeError
|
||||
"""
|
||||
logger.info("test_random_affine_exception_negative_degrees")
|
||||
dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3)
|
||||
try:
|
||||
transform = py_vision.ComposeOp([py_vision.ToTensor(),
|
||||
py_vision.RandomAffine(degrees=(15, 15))])
|
||||
dataset = dataset.map(input_columns=["image"], operations=transform(), num_parallel_workers=3,
|
||||
python_multiprocessing=True)
|
||||
for _ in dataset.create_dict_iterator():
|
||||
break
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Pillow image" in str(e)
|
||||
|
||||
|
||||
def test_random_affine_exception_negative_degrees():
|
||||
"""
|
||||
Test RandomAffine: input degrees in negative, expected to raise ValueError
|
||||
|
@ -289,6 +308,7 @@ if __name__ == "__main__":
|
|||
test_random_affine_op_c(plot=True)
|
||||
test_random_affine_md5()
|
||||
test_random_affine_c_md5()
|
||||
test_random_affine_py_exception_non_pil_images()
|
||||
test_random_affine_exception_negative_degrees()
|
||||
test_random_affine_exception_translation_range()
|
||||
test_random_affine_exception_scale_value()
|
||||
|
|
Loading…
Reference in New Issue