2020-03-27 14:49:12 +08:00
|
|
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
import numpy as np
|
2020-05-18 16:42:35 +08:00
|
|
|
|
2020-03-27 14:49:12 +08:00
|
|
|
import mindspore.dataset as ds
|
|
|
|
import mindspore.dataset.transforms.c_transforms as c
|
2020-05-18 16:42:35 +08:00
|
|
|
import mindspore.dataset.transforms.py_transforms as f
|
2020-08-28 03:30:21 +08:00
|
|
|
import mindspore.dataset.vision.c_transforms as c_vision
|
|
|
|
import mindspore.dataset.vision.py_transforms as py_vision
|
2020-03-27 14:49:12 +08:00
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
# Training splits of the two ImageNet-style fixture datasets used below.
# Both paths are relative, so they resolve against the current working directory.
DATA_DIR, DATA_DIR_2 = (
    "../data/dataset/testImageNetData/train",
    "../data/dataset/testImageNetData2/train",
)
|
|
|
|
|
|
|
|
|
|
|
|
def test_one_hot_op():
    """
    Test one hot encoding op with label smoothing.

    Every label is expected to be encoded as a vector of length ``num_classes``
    whose hot entry equals ``1 - epsilon + epsilon / num_classes`` and whose
    other entries equal ``epsilon / num_classes`` (for ``num_classes == 2`` this
    matches the previous golden value ``1 - epsilon / num_classes``).
    """
    logger.info("Test one hot encoding op")

    # define map operations
    dataset = ds.ImageFolderDataset(DATA_DIR)
    num_classes = 2
    epsilon_para = 0.1

    transforms = [f.OneHotOp(num_classes=num_classes, smoothing_rate=epsilon_para)]
    transform_label = f.Compose(transforms)
    dataset = dataset.map(operations=transform_label, input_columns=["label"])

    off_value = epsilon_para / num_classes
    on_value = 1 - epsilon_para + off_value

    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # Flatten in case the op returns a leading singleton dimension, so the
        # shape matches the 1-D golden vector.
        label = np.ravel(data["label"])
        # Build the golden vector around this sample's own hot index so the
        # check does not depend on which class folder the sample came from.
        golden_label = np.full(num_classes, off_value)
        golden_label[np.argmax(label)] = on_value
        logger.info("label is {}".format(label))
        logger.info("golden_label is {}".format(golden_label))
        # Element-wise comparison: `a.all() == b.all()` only compared two booleans.
        np.testing.assert_allclose(label, golden_label, rtol=1e-6, atol=1e-6)
    logger.info("====test one hot op ok====")
|
|
|
|
|
|
|
|
|
|
|
|
def test_mix_up_single():
    """
    Test single batch mix up op.

    With ``is_single=True`` the test expects mixed sample ``i`` to equal
    ``lam * x[i] + (1 - lam) * x[i + 1]``, where ``x`` is the un-mixed batch.
    Assuming labels are mixed with the same ``lam``, two distinct one-hot labels
    make the mixed label differ from the raw one by ``1 - lam`` in two entries,
    so ``lam = 1 - sum(|diff|) / 2``.
    """
    logger.info("Test single batch mix up op")

    resize_height = 224
    resize_width = 224

    # Create dataset and define map operations
    ds1 = ds.ImageFolderDataset(DATA_DIR_2)

    num_classes = 10
    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
    one_hot_encode = c.OneHot(num_classes)  # num_classes is input argument

    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.map(operations=resize_op, input_columns=["image"])
    ds1 = ds1.map(operations=one_hot_encode, input_columns=["label"])

    # apply batch operations
    batch_size = 3
    ds1 = ds1.batch(batch_size, drop_remainder=True)

    # ds2 keeps the pipeline as it is before MixUp is mapped on; it is the reference.
    ds2 = ds1
    alpha = 0.2
    transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=True)]
    ds1 = ds1.map(operations=transforms, input_columns=["image", "label"])

    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            ds2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image1 = data1["image"]
        label = data1["label"]
        logger.info("label is {}".format(label))

        image2 = data2["image"]
        label2 = data2["label"]
        logger.info("label2 is {}".format(label2))

        label_diff = np.abs(label - label2)
        for index in range(batch_size - 1):
            # No difference in the labels means lam cannot be recovered; skip.
            if not np.any(label_diff[index]):
                continue
            lam_value = 1 - np.sum(label_diff[index]) / 2
            img_golden = lam_value * image2[index] + (1 - lam_value) * image2[index + 1]
            # Element-wise comparison: `a.all() == b.all()` only compared two booleans.
            np.testing.assert_allclose(image1[index], img_golden, rtol=1e-5, atol=1e-5)
    logger.info("====test single batch mixup ok====")
|
|
|
|
|
|
|
|
|
|
|
|
def test_mix_up_multi():
    """
    Test multi batch mix up op.

    With ``is_single=False`` the test expects each sample of the second mixed
    batch to equal ``lam * x2[i] + (1 - lam) * m1[i]``. Here ``x2`` is the
    un-mixed second batch and ``m1`` is the first mixed batch returned by the op.
    ``lam`` is recovered from the labels as ``1 - sum(|diff|) / 2``.
    """
    logger.info("Test several batch mix up op")

    resize_height = 224
    resize_width = 224

    # Create dataset and define map operations
    ds1 = ds.ImageFolderDataset(DATA_DIR_2)

    num_classes = 3
    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
    one_hot_encode = c.OneHot(num_classes)  # num_classes is input argument

    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.map(operations=resize_op, input_columns=["image"])
    ds1 = ds1.map(operations=one_hot_encode, input_columns=["label"])

    # apply batch operations
    batch_size = 3
    ds1 = ds1.batch(batch_size, drop_remainder=True)

    # ds2 keeps the pipeline as it is before MixUp is mapped on; it is the reference.
    ds2 = ds1
    alpha = 0.2
    transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=False)]
    ds1 = ds1.map(operations=transforms, input_columns=["image", "label"])

    first_mixed_images = None
    for num_iter, (data1, data2) in enumerate(
            zip(ds1.create_dict_iterator(num_epochs=1, output_numpy=True),
                ds2.create_dict_iterator(num_epochs=1, output_numpy=True))):
        image1 = data1["image"]
        label1 = data1["label"]
        logger.info("label: {}".format(label1))

        image2 = data2["image"]
        label2 = data2["label"]
        logger.info("label2: {}".format(label2))

        if num_iter == 0:
            # Keep the first mixed batch; the second batch is checked against it.
            first_mixed_images = image1
            continue

        label_diff = np.abs(label2 - label1)
        logger.info("lam value in multi: {}".format(label_diff))
        for index in range(batch_size):
            # No difference in the labels means lam cannot be recovered; skip.
            if not np.any(label_diff[index]):
                continue
            lam_value = 1 - np.sum(label_diff[index]) / 2
            img_golden = lam_value * image2[index] + (1 - lam_value) * first_mixed_images[index]
            # Element-wise comparison: `a.all() == b.all()` only compared two booleans.
            np.testing.assert_allclose(image1[index], img_golden, rtol=1e-5, atol=1e-5)
        logger.info("====test several batch mixup ok====")
        # Only the first two batches are needed for this check.
        break
|
2020-03-27 14:49:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every test in this module, in definition order, when executed as a script.
    for _test_fn in (test_one_hot_op, test_mix_up_single, test_mix_up_multi):
        _test_fn()
|