revert 'Pull Request !21685 : add function to invert_op'

This commit is contained in:
guozhijian 2021-08-19 06:53:33 +00:00 committed by jonyguo
parent c93bb9936e
commit e1df9a2152
6 changed files with 14 additions and 55 deletions

View File

@ -32,27 +32,22 @@ Status InvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Invert: load image failed.");
}
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Invert: image shape is not <H,W,C> or <H,W>");
if (input_cv->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("Invert: image shape is not <H,W,C>");
}
int num_channels = 1;
if (input_cv->Rank() == 3) {
num_channels = input_cv->shape()[2];
}
if (num_channels != 3 && num_channels != 1) {
RETURN_STATUS_UNEXPECTED("Invert: image shape is incorrect, expected num of channels is 1 or 3, but got:" +
int num_channels = input_cv->shape()[2];
if (num_channels != 3) {
RETURN_STATUS_UNEXPECTED(
"Invert: image shape is incorrect, expected num of channels is 3, "
"but got:" +
std::to_string(num_channels));
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
RETURN_UNEXPECTED_IF_NULL(output_cv);
if(num_channels == 3){
output_cv->mat() = cv::Scalar::all(255) - input_img;
}
else{
output_cv->mat() = cv::Scalar(255) - input_img;
}
output_cv->mat() = cv::Scalar::all(255) - input_img;
*output = std::static_pointer_cast<Tensor>(output_cv);
}
@ -63,4 +58,3 @@ Status InvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
}
} // namespace dataset
} // namespace mindspore

Binary file not shown.

Before

Width:  |  Height:  |  Size: 403 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 403 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 403 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 403 KiB

View File

@ -25,8 +25,6 @@ from mindspore import log as logger
from util import visualize_list, save_and_check_md5, diff_mse
DATA_DIR = "../data/dataset/testImageNetData/train/"
DATA_DIR0 = "../data/dataset/testImageNetData/train0/"
DATA_DIR1 = "../data/dataset/testImageNetData/train1/"
GENERATE_GOLDEN = False
@ -203,7 +201,8 @@ def test_invert_py_c(plot=False):
if plot:
visualize_list(images_c_invert, images_p_invert, visualize_mode=2)
def test_invert_one_channel(plot=False):
def test_invert_one_channel():
"""
Test Invert cpp op with one channel image
"""
@ -212,51 +211,17 @@ def test_invert_one_channel(plot=False):
c_op = C.Invert()
try:
data_set0 = ds.ImageFolderDataset(dataset_dir=DATA_DIR0, shuffle=False)
data_set0 = data_set0.map(operations=[C.Decode(), C.Resize((224, 224)),
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])
data_set1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR1, shuffle=False)
data_set1 = data_set1.map(operations=[C.Decode(), C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])
data_uninvert = data_set0
data_uninvert = data_uninvert.batch(512)
for idx, (image, _) in enumerate(data_uninvert):
if idx == 0:
images_0_invert = image.asnumpy()
else:
images_0_invert = np.append(images_0_invert,
image.asnumpy(),
axis=0)
data_invert1 = data_set1.map(operations=c_op, input_columns="image")
data_invert1 = data_invert1.batch(512)
for idx, (image, _) in enumerate(data_invert1):
if idx == 0:
images_1_invert = image.asnumpy()
else:
images_1_invert = np.append(images_1_invert,
image.asnumpy(),
axis=0)
num_samples = images_0_invert.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_0_invert[i], images_1_invert[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_0_invert, images_1_invert, visualize_mode=2)
data_set.map(operations=c_op, input_columns="image")
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "The shape" in str(e)
def test_invert_md5_py():
"""
Test Invert python op with md5 check