forked from mindspore-Ecosystem/mindspore
!5126 Fix problem in RandomPosterize & CutMixBatch
Merge pull request !5126 from luoyang/son_r0.7
This commit is contained in:
commit
cee889e426
|
@ -50,7 +50,7 @@ void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y,
|
|||
|
||||
Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||
if (input.size() < 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
|
||||
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation.");
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Tensor>> images;
|
||||
|
@ -59,10 +59,10 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
|
||||
// Check inputs
|
||||
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling CutMixBatch.");
|
||||
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batched before calling CutMixBatch.");
|
||||
}
|
||||
if (label_shape.size() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch");
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch.");
|
||||
}
|
||||
if ((image_shape[1] != 1 && image_shape[1] != 3) && image_batch_format_ == ImageBatchFormat::kNCHW) {
|
||||
RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format.");
|
||||
|
|
|
@ -415,9 +415,7 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
|
|||
for (int i = 0; i < crop_width; i++) {
|
||||
for (int j = 0; j < crop_height; j++) {
|
||||
for (int c = 0; c < number_of_channels; c++) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i, c}));
|
||||
RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i, c}, pixel_value));
|
||||
RETURN_IF_NOT_OK(CopyTensorValue(sub_mat, input, {j, i, c}, {y + j, x + i, c}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -432,9 +430,7 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
|
|||
for (int i = 0; i < crop_width; i++) {
|
||||
for (int j = 0; j < crop_height; j++) {
|
||||
for (int c = 0; c < number_of_channels; c++) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {c, j, i}));
|
||||
RETURN_IF_NOT_OK((*input)->SetItemAt({c, y + j, x + i}, pixel_value));
|
||||
RETURN_IF_NOT_OK(CopyTensorValue(sub_mat, input, {c, j, i}, {c, y + j, x + i}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -447,9 +443,7 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
|
|||
}
|
||||
for (int i = 0; i < crop_width; i++) {
|
||||
for (int j = 0; j < crop_height; j++) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i}));
|
||||
RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i}, pixel_value));
|
||||
RETURN_IF_NOT_OK(CopyTensorValue(sub_mat, input, {j, i}, {y + j, x + i}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -458,6 +452,24 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CopyTensorValue(const std::shared_ptr<Tensor> &source_tensor, std::shared_ptr<Tensor> *dest_tensor,
|
||||
const std::vector<int64_t> &source_indx, const std::vector<int64_t> &dest_indx) {
|
||||
if (source_tensor->type() != (*dest_tensor)->type())
|
||||
RETURN_STATUS_UNEXPECTED("CopyTensorValue: source and destination tensor must have the same type.");
|
||||
if (source_tensor->type() == DataType::DE_UINT8) {
|
||||
uint8_t pixel_value;
|
||||
RETURN_IF_NOT_OK(source_tensor->GetItemAt(&pixel_value, source_indx));
|
||||
RETURN_IF_NOT_OK((*dest_tensor)->SetItemAt(dest_indx, pixel_value));
|
||||
} else if (source_tensor->type() == DataType::DE_FLOAT32) {
|
||||
float pixel_value;
|
||||
RETURN_IF_NOT_OK(source_tensor->GetItemAt(&pixel_value, source_indx));
|
||||
RETURN_IF_NOT_OK((*dest_tensor)->SetItemAt(dest_indx, pixel_value));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("CopyTensorValue: Tensor type is not supported. Tensor type must be float32 or uint8.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
|
||||
|
|
|
@ -133,6 +133,17 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
|
|||
Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, int width,
|
||||
int height, ImageFormat image_format);
|
||||
|
||||
/// \brief Copies a value from a source tensor into a destination tensor
|
||||
/// \note This is meant for images and therefore only works if tensor is uint8 or float32
|
||||
/// \param[in] source_tensor The tensor we take the value from
|
||||
/// \param[in] dest_tensor The pointer to the tensor we want to copy the value to
|
||||
/// \param[in] source_indx index of the value in the source tensor
|
||||
/// \param[in] dest_indx index of the value in the destination tensor
|
||||
/// \param[out] dest_tensor Copies the value to the given dest_tensor and returns it
|
||||
/// @return Status ok/error
|
||||
Status CopyTensorValue(const std::shared_ptr<Tensor> &source_tensor, std::shared_ptr<Tensor> *dest_tensor,
|
||||
const std::vector<int64_t> &source_indx, const std::vector<int64_t> &dest_indx);
|
||||
|
||||
/// \brief Swap the red and blue pixels (RGB <-> BGR)
|
||||
/// \param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor.
|
||||
/// \param output: Swapped image of same shape and type
|
||||
|
|
|
@ -38,13 +38,13 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
|
||||
// Check inputs
|
||||
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
|
||||
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch");
|
||||
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch.");
|
||||
}
|
||||
if (label_shape.size() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch");
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch.");
|
||||
}
|
||||
if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) {
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW");
|
||||
RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW.");
|
||||
}
|
||||
|
||||
// Move images into a vector of CVTensors
|
||||
|
|
|
@ -40,6 +40,8 @@ Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
|
|||
}
|
||||
cv::Mat in_image = input_cv->mat();
|
||||
cv::Mat output_img;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in_image.depth() == CV_8U || in_image.depth() == CV_8S,
|
||||
"Input image data type can not be float, but got " + input->type().ToString());
|
||||
cv::LUT(in_image, lut_vector, output_img);
|
||||
std::shared_ptr<CVTensor> result_tensor;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));
|
||||
|
|
|
@ -76,7 +76,7 @@ def test_cutmix_batch_success1(plot=False):
|
|||
|
||||
def test_cutmix_batch_success2(plot=False):
|
||||
"""
|
||||
Test CutMixBatch op with default values for alpha and prob on a batch of HWC images
|
||||
Test CutMixBatch op with default values for alpha and prob on a batch of rescaled HWC images
|
||||
"""
|
||||
logger.info("test_cutmix_batch_success2")
|
||||
|
||||
|
@ -95,6 +95,8 @@ def test_cutmix_batch_success2(plot=False):
|
|||
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
|
||||
one_hot_op = data_trans.OneHot(num_classes=10)
|
||||
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
|
||||
rescale_op = vision.Rescale((1.0/255.0), 0.0)
|
||||
data1 = data1.map(input_columns=["image"], operations=rescale_op)
|
||||
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
|
||||
data1 = data1.batch(5, drop_remainder=True)
|
||||
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
|
||||
|
|
|
@ -142,8 +142,29 @@ def test_random_posterize_exception_bit():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."
|
||||
|
||||
def test_rescale_with_random_posterize():
|
||||
"""
|
||||
Test RandomPosterize: only support CV_8S/CV_8U
|
||||
"""
|
||||
logger.info("test_rescale_with_random_posterize")
|
||||
|
||||
DATA_DIR_10 = "../data/dataset/testCifar10Data"
|
||||
dataset = ds.Cifar10Dataset(DATA_DIR_10)
|
||||
|
||||
rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0)
|
||||
dataset = dataset.map(input_columns=["image"], operations=rescale_op)
|
||||
|
||||
random_posterize_op = c_vision.RandomPosterize((4, 8))
|
||||
dataset = dataset.map(input_columns=["image"], operations=random_posterize_op, num_parallel_workers=1)
|
||||
|
||||
try:
|
||||
_ = dataset.output_shapes()
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input image data type can not be float" in str(e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
skip_test_random_posterize_op_c(plot=True)
|
||||
skip_test_random_posterize_op_fixed_point_c(plot=True)
|
||||
test_random_posterize_exception_bit()
|
||||
test_rescale_with_random_posterize()
|
||||
|
|
Loading…
Reference in New Issue