!4955 Fixes for Dynamic Augmentation Ops

Merge pull request !4955 from MahdiRahmaniHanzaki/dynamic-ops-fix
Commit 9b503e4f38 by mindspore-ci-bot, 2020-08-22 10:27:50 +08:00, committed by Gitee
11 changed files with 285 additions and 56 deletions


@@ -382,7 +382,7 @@ CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format,
     : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
 bool CutMixBatchOperation::ValidateParams() {
-  if (alpha_ < 0) {
+  if (alpha_ <= 0) {
     MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
     return false;
   }
@@ -434,7 +434,7 @@ std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<H
 MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
 bool MixUpBatchOperation::ValidateParams() {
-  if (alpha_ < 0) {
+  if (alpha_ <= 0) {
     MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
     return false;
   }
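Note on the tightened check: both batch ops draw their mixing coefficient from a Beta(alpha, alpha) distribution, which is undefined for alpha <= 0, so zero now has to be rejected along with negative values. A minimal sketch of that constraint (illustrative only, not the operator's code; numpy's beta sampler likewise rejects non-positive parameters):

import numpy as np

def sample_mix_lambda(alpha):
    # Beta(alpha, alpha) is only defined for alpha > 0, mirroring the <= 0 check added above.
    if alpha <= 0:
        raise ValueError("alpha must be a positive float, got {}".format(alpha))
    return np.random.beta(alpha, alpha)

print(sample_mix_lambda(2.0))   # e.g. 0.43 -- valid
# sample_mix_lambda(0.0)        # rejected: zero would make the distribution undefined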


@@ -59,7 +59,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
   // Check inputs
   if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
-    RETURN_STATUS_UNEXPECTED("You must batch before calling CutMixBatch.");
+    RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling CutMixBatch.");
   }
   if (label_shape.size() != 2) {
     RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch");
@@ -139,10 +139,17 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
       // Compute labels
       for (int j = 0; j < label_shape[1]; j++) {
-        uint64_t first_value, second_value;
-        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
-        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
-        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
+        if (input.at(1)->type().IsSignedInt()) {
+          int64_t first_value, second_value;
+          RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+          RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
+          RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
+        } else {
+          uint64_t first_value, second_value;
+          RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+          RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
+          RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
+        }
       }
     }
   }
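The branch above exists because labels can arrive as signed integers (for example int32 one-hot output from OneHot); reading them through uint64_t reinterprets any negative value as a huge unsigned number before mixing. A hedged Python sketch of the blend both branches compute, with illustrative names (not the operator's internals):

import numpy as np

def mix_one_hot_labels(labels, i, j, lam):
    # Cast to float before blending so signed values keep their sign,
    # matching the float32 out_labels tensor the op writes into.
    first = labels[i].astype(np.float32)
    second = labels[j].astype(np.float32)
    return lam * first + (1.0 - lam) * second

labels = np.eye(4, dtype=np.int32)              # signed one-hot labels
print(mix_one_hot_labels(labels, 0, 2, 0.7))    # [0.7 0.  0.3 0. ]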


@@ -38,7 +38,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
   // Check inputs
   if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
-    RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
+    RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch");
   }
   if (label_shape.size() != 2) {
     RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch");
@@ -68,10 +68,17 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
   RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType("float32")));
   for (int64_t i = 0; i < label_shape[0]; i++) {
     for (int64_t j = 0; j < label_shape[1]; j++) {
-      uint64_t first_value, second_value;
-      RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
-      RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
-      RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
+      if (input.at(1)->type().IsSignedInt()) {
+        int64_t first_value, second_value;
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
+        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
+      } else {
+        uint64_t first_value, second_value;
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
+        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
+      }
     }
   }


@@ -231,7 +231,7 @@ class Normalize(cde.NormalizeOp):
 class RandomAffine(cde.RandomAffineOp):
     """
-    Apply Random affine transformation to the input PIL image.
+    Apply Random affine transformation to the input image.
     Args:
         degrees (int or float or sequence): Range of the rotation degrees.
@@ -681,12 +681,12 @@ class CenterCrop(cde.CenterCropOp):
 class RandomColor(cde.RandomColorOp):
     """
     Adjust the color of the input image by a fixed or random degree.
+    This operation works only with 3-channel color images.
     Args:
         degrees (sequence): Range of random color adjustment degrees.
             It should be in (min, max) format. If min=max, then it is a
             single fixed magnitude operation (default=(0.1,1.9)).
-            Works with 3-channel color images.
     """
     @check_positive_degrees


@@ -1169,39 +1169,12 @@ class RandomAffine:
         Returns:
             img (PIL Image), Randomly affine transformed image.
         """
-        # rotation
-        angle = random.uniform(self.degrees[0], self.degrees[1])
-        # translation
-        if self.translate is not None:
-            max_dx = self.translate[0] * img.size[0]
-            max_dy = self.translate[1] * img.size[1]
-            translations = (np.round(random.uniform(-max_dx, max_dx)),
-                            np.round(random.uniform(-max_dy, max_dy)))
-        else:
-            translations = (0, 0)
-        # scale
-        if self.scale_ranges is not None:
-            scale = random.uniform(self.scale_ranges[0], self.scale_ranges[1])
-        else:
-            scale = 1.0
-        # shear
-        if self.shear is not None:
-            if len(self.shear) == 2:
-                shear = [random.uniform(self.shear[0], self.shear[1]), 0.]
-            elif len(self.shear) == 4:
-                shear = [random.uniform(self.shear[0], self.shear[1]),
-                         random.uniform(self.shear[2], self.shear[3])]
-        else:
-            shear = 0.0
         return util.random_affine(img,
-                                  angle,
-                                  translations,
-                                  scale,
-                                  shear,
+                                  self.degrees,
+                                  self.translate,
+                                  self.scale_ranges,
+                                  self.shear,
                                   self.resample,
                                   self.fill_value)


@@ -1153,6 +1153,34 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
     if not is_pil(img):
         raise ValueError("Input image should be a Pillow image.")
+    # rotation
+    angle = random.uniform(angle[0], angle[1])
+    # translation
+    if translations is not None:
+        max_dx = translations[0] * img.size[0]
+        max_dy = translations[1] * img.size[1]
+        translations = (np.round(random.uniform(-max_dx, max_dx)),
+                        np.round(random.uniform(-max_dy, max_dy)))
+    else:
+        translations = (0, 0)
+    # scale
+    if scale is not None:
+        scale = random.uniform(scale[0], scale[1])
+    else:
+        scale = 1.0
+    # shear
+    if shear is not None:
+        if len(shear) == 2:
+            shear = [random.uniform(shear[0], shear[1]), 0.]
+        elif len(shear) == 4:
+            shear = [random.uniform(shear[0], shear[1]),
+                     random.uniform(shear[2], shear[3])]
+    else:
+        shear = 0.0
     output_size = img.size
     center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
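With the hunk above, util.random_affine now receives the configured ranges and draws the concrete angle, translation, scale, and shear itself, so the eager RandomAffine op in the previous file simply forwards its attributes. A minimal usage sketch under that assumption; the py_vision import path and the constructor arguments other than degrees are taken from the tests and docstrings elsewhere in this change set and should be treated as assumptions:

import numpy as np
from PIL import Image
import mindspore.dataset.transforms.vision.py_transforms as py_vision

# random_affine still requires a PIL image; parameters are now sampled per call inside it.
img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
affine = py_vision.RandomAffine(degrees=(15, 15), translate=(0.1, 0.1), scale=(0.9, 1.1))
out = affine(img)  # ranges are forwarded unchanged; concrete values are drawn per call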
@@ -1416,7 +1444,6 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
 def random_color(img, degrees):
     """
     Adjust the color of the input PIL image by a random degree.
@@ -1437,7 +1464,6 @@ def random_color(img, degrees):
 def random_sharpness(img, degrees):
     """
     Adjust the sharpness of the input PIL image by a random degree.
@@ -1458,7 +1484,6 @@ def random_sharpness(img, degrees):
 def auto_contrast(img, cutoff, ignore):
     """
     Automatically maximize the contrast of the input PIL image.
@@ -1479,7 +1504,6 @@ def auto_contrast(img, cutoff, ignore):
 def invert_color(img):
     """
     Invert colors of input PIL image.
@@ -1498,7 +1522,6 @@ def invert_color(img):
 def equalize(img):
     """
     Equalize the histogram of input PIL image.
@@ -1517,7 +1540,6 @@ def equalize(img):
 def uniform_augment(img, transforms, num_ops):
     """
     Uniformly select and apply a number of transforms sequentially from
     a list of transforms. Randomly assigns a probability to each transform for


@@ -45,6 +45,7 @@ def check_cut_mix_batch_c(method):
         [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
         type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
         check_pos_float32(alpha)
+        check_positive(alpha, "alpha")
         check_value(prob, [0, 1], "prob")
         return method(self, *args, **kwargs)
@@ -68,6 +69,7 @@ def check_mix_up_batch_c(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
         [alpha], _ = parse_user_args(method, *args, **kwargs)
+        check_positive(alpha, "alpha")
         check_pos_float32(alpha)
         return method(self, *args, **kwargs)
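With check_positive added, a zero alpha is now rejected in Python before the C++ op is ever constructed, matching the <= 0 checks in ValidateParams. A usage sketch grounded in the new tests below; the import paths are assumptions for this era of the API, and the message text comes from those tests:

import pytest
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.dataset.transforms.vision.utils import ImageBatchFormat

with pytest.raises(ValueError) as error:
    vision.CutMixBatch(ImageBatchFormat.NHWC, 0.0)   # alpha == 0 now fails validation
assert "Input is not within the required interval" in str(error.value)

with pytest.raises(ValueError):
    vision.MixUpBatch(0.0)                           # same rule for MixUpBatch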


@@ -191,11 +191,37 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
   ds = ds->Map({one_hot_op},{"label"});
   EXPECT_NE(ds, nullptr);
-  std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
+  std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
+                                                                         1, -0.5);
   EXPECT_EQ(cutmix_batch_op, nullptr);
 }
+TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
+  // Must fail because alpha can't be zero
+  // Create a Cifar10 Dataset
+  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  EXPECT_NE(ds, nullptr);
+  // Create a Batch operation on ds
+  int32_t batch_size = 5;
+  ds = ds->Batch(batch_size);
+  EXPECT_NE(ds, nullptr);
+  // Create objects for the tensor ops
+  std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
+  EXPECT_NE(one_hot_op, nullptr);
+  // Create a Map operation on ds
+  ds = ds->Map({one_hot_op},{"label"});
+  EXPECT_NE(ds, nullptr);
+  std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
+                                                                         0.0, 0.5);
+  EXPECT_EQ(cutmix_batch_op, nullptr);
+}
 TEST_F(MindDataTestPipeline, TestCutOut) {
   // Create an ImageFolder Dataset
   std::string folder_path = datasets_root_path_ + "/testPK/data/";
@@ -365,6 +391,30 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
   EXPECT_EQ(mixup_batch_op, nullptr);
 }
+TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
+  // This should fail because alpha can't be zero
+  // Create a Cifar10 Dataset
+  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  EXPECT_NE(ds, nullptr);
+  // Create a Batch operation on ds
+  int32_t batch_size = 5;
+  ds = ds->Batch(batch_size);
+  EXPECT_NE(ds, nullptr);
+  // Create objects for the tensor ops
+  std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
+  EXPECT_NE(one_hot_op, nullptr);
+  // Create a Map operation on ds
+  ds = ds->Map({one_hot_op}, {"label"});
+  EXPECT_NE(ds, nullptr);
+  std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.0);
+  EXPECT_EQ(mixup_batch_op, nullptr);
+}
 TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
@@ -384,7 +434,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
   ds = ds->Map({one_hot_op}, {"label"});
   EXPECT_NE(ds, nullptr);
-  std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
+  std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(2.0);
   EXPECT_NE(mixup_batch_op, nullptr);
   // Create a Map operation on ds


@@ -26,6 +26,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
     config_get_set_num_parallel_workers
 DATA_DIR = "../data/dataset/testCifar10Data"
+DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
 GENERATE_GOLDEN = False
@@ -114,6 +115,53 @@ def test_cutmix_batch_success2(plot=False):
     logger.info("MSE= {}".format(str(np.mean(mse))))
+def test_cutmix_batch_success3(plot=False):
+    """
+    Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2
+    """
+    logger.info("test_cutmix_batch_success3")
+    ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
+    ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
+    images_original = None
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original, image, axis=0)
+    # CutMix Images
+    data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=[decode_op])
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
+    data1 = data1.batch(4, pad_info={}, drop_remainder=True)
+    data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
+    images_cutmix = None
+    for idx, (image, _) in enumerate(data1):
+        if idx == 0:
+            images_cutmix = image
+        else:
+            images_cutmix = np.append(images_cutmix, image, axis=0)
+    if plot:
+        visualize_list(images_original, images_cutmix)
+    num_samples = images_original.shape[0]
+    mse = np.zeros(num_samples)
+    for i in range(num_samples):
+        mse[i] = diff_mse(images_cutmix[i], images_original[i])
+    logger.info("MSE= {}".format(str(np.mean(mse))))
 def test_cutmix_batch_nhwc_md5():
     """
     Test CutMixBatch on a batch of HWC images with MD5:
@@ -185,7 +233,7 @@ def test_cutmix_batch_fail1():
                 images_cutmix = image
             else:
                 images_cutmix = np.append(images_cutmix, image, axis=0)
-    error_message = "You must batch before calling CutMixBatch"
+    error_message = "You must make sure images are HWC or CHW and batch "
     assert error_message in str(error.value)
@@ -322,9 +370,28 @@ def test_cutmix_batch_fail7():
     assert error_message in str(error.value)
+def test_cutmix_batch_fail8():
+    """
+    Test CutMixBatch Fail 8
+    We expect this to fail because alpha is zero
+    """
+    logger.info("test_cutmix_batch_fail8")
+    # CutMixBatch Images
+    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+    with pytest.raises(ValueError) as error:
+        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
+    error_message = "Input is not within the required interval"
+    assert error_message in str(error.value)
 if __name__ == "__main__":
     test_cutmix_batch_success1(plot=True)
     test_cutmix_batch_success2(plot=True)
+    test_cutmix_batch_success3(plot=True)
     test_cutmix_batch_nchw_md5()
     test_cutmix_batch_nhwc_md5()
     test_cutmix_batch_fail1()
@@ -334,3 +401,4 @@ if __name__ == "__main__":
     test_cutmix_batch_fail5()
     test_cutmix_batch_fail6()
     test_cutmix_batch_fail7()
+    test_cutmix_batch_fail8()


@@ -25,6 +25,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
     config_get_set_num_parallel_workers
 DATA_DIR = "../data/dataset/testCifar10Data"
+DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
 GENERATE_GOLDEN = False
@@ -71,11 +72,59 @@ def test_mixup_batch_success1(plot=False):
 def test_mixup_batch_success2(plot=False):
+    """
+    Test MixUpBatch op with specified alpha parameter on ImageFolderDatasetV2
+    """
+    logger.info("test_mixup_batch_success2")
+    # Original Images
+    ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
+    ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
+    images_original = None
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original, image, axis=0)
+    # MixUp Images
+    data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=[decode_op])
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+    mixup_batch_op = vision.MixUpBatch(2.0)
+    data1 = data1.batch(4, pad_info={}, drop_remainder=True)
+    data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
+    images_mixup = None
+    for idx, (image, _) in enumerate(data1):
+        if idx == 0:
+            images_mixup = image
+        else:
+            images_mixup = np.append(images_mixup, image, axis=0)
+    if plot:
+        visualize_list(images_original, images_mixup)
+    num_samples = images_original.shape[0]
+    mse = np.zeros(num_samples)
+    for i in range(num_samples):
+        mse[i] = diff_mse(images_mixup[i], images_original[i])
+    logger.info("MSE= {}".format(str(np.mean(mse))))
+def test_mixup_batch_success3(plot=False):
     """
     Test MixUpBatch op without specified alpha parameter.
     Alpha parameter will be selected by default in this case
     """
-    logger.info("test_mixup_batch_success2")
+    logger.info("test_mixup_batch_success3")
     # Original Images
     ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
@@ -169,7 +218,7 @@ def test_mixup_batch_fail1():
                 images_mixup = image
             else:
                 images_mixup = np.append(images_mixup, image, axis=0)
-    error_message = "You must batch before calling MixUp"
+    error_message = "You must make sure images are HWC or CHW and batch"
     assert error_message in str(error.value)
@@ -207,6 +256,7 @@ def test_mixup_batch_fail3():
     Test MixUpBatch op
     We expect this to fail because label column is not passed to mixup_batch
     """
+    logger.info("test_mixup_batch_fail3")
     # Original Images
     ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
     ds_original = ds_original.batch(5, drop_remainder=True)
@@ -237,11 +287,41 @@ def test_mixup_batch_fail3():
     error_message = "Both images and labels columns are required"
     assert error_message in str(error.value)
+def test_mixup_batch_fail4():
+    """
+    Test MixUpBatch Fail 2
+    We expect this to fail because alpha is zero
+    """
+    logger.info("test_mixup_batch_fail4")
+    # Original Images
+    ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
+    ds_original = ds_original.batch(5)
+    images_original = np.array([])
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original, image, axis=0)
+    # MixUp Images
+    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+    with pytest.raises(ValueError) as error:
+        vision.MixUpBatch(0.0)
+    error_message = "Input is not within the required interval"
+    assert error_message in str(error.value)
 if __name__ == "__main__":
     test_mixup_batch_success1(plot=True)
     test_mixup_batch_success2(plot=True)
+    test_mixup_batch_success3(plot=True)
     test_mixup_batch_md5()
     test_mixup_batch_fail1()
     test_mixup_batch_fail2()
     test_mixup_batch_fail3()
+    test_mixup_batch_fail4()


@@ -27,6 +27,7 @@ GENERATE_GOLDEN = False
 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+MNIST_DATA_DIR = "../data/dataset/testMnistData"
 def test_random_affine_op(plot=False):
@@ -155,6 +156,24 @@ def test_random_affine_c_md5():
     ds.config.set_num_parallel_workers((original_num_parallel_workers))
+def test_random_affine_py_exception_non_pil_images():
+    """
+    Test RandomAffine: input img is ndarray and not PIL, expected to raise TypeError
+    """
+    logger.info("test_random_affine_exception_negative_degrees")
+    dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3)
+    try:
+        transform = py_vision.ComposeOp([py_vision.ToTensor(),
+                                         py_vision.RandomAffine(degrees=(15, 15))])
+        dataset = dataset.map(input_columns=["image"], operations=transform(), num_parallel_workers=3,
+                              python_multiprocessing=True)
+        for _ in dataset.create_dict_iterator():
+            break
+    except RuntimeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Pillow image" in str(e)
 def test_random_affine_exception_negative_degrees():
     """
     Test RandomAffine: input degrees in negative, expected to raise ValueError
@@ -289,6 +308,7 @@ if __name__ == "__main__":
     test_random_affine_op_c(plot=True)
     test_random_affine_md5()
     test_random_affine_c_md5()
+    test_random_affine_py_exception_non_pil_images()
     test_random_affine_exception_negative_degrees()
     test_random_affine_exception_translation_range()
     test_random_affine_exception_scale_value()