diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc index 4ba135c7188..9c18ccf7480 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_invert_op.cc @@ -13,14 +13,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "minddata/dataset/kernels/image/random_invert_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" + namespace mindspore { namespace dataset { const float RandomInvertOp::kDefProbability = 0.5; Status RandomInvertOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + // check input + if (input->Rank() != DEFAULT_IMAGE_RANK) { + RETURN_STATUS_UNEXPECTED("RandomInvert: image shape is not , got rank: " + std::to_string(input->Rank())); + } + if (input->shape()[CHANNEL_INDEX] != DEFAULT_IMAGE_CHANNELS) { + RETURN_STATUS_UNEXPECTED( + "RandomInvert: image shape is incorrect, expected num of channels is 3, " + "but got:" + + std::to_string(input->shape()[CHANNEL_INDEX])); + } + CHECK_FAIL_RETURN_UNEXPECTED(input->type().AsCVType() != kCVInvalidType, + "RandomInvert: Cannot convert from OpenCV type, unknown CV type. Currently " + "supported data type: [int8, uint8, int16, uint16, int32, float16, float32, float64]."); if (distribution_(rnd_)) { return InvertOp::Compute(input, output); } diff --git a/tests/ut/python/dataset/test_random_invert.py b/tests/ut/python/dataset/test_random_invert.py index d0b177f4396..9e9f1ebacb7 100644 --- a/tests/ut/python/dataset/test_random_invert.py +++ b/tests/ut/python/dataset/test_random_invert.py @@ -122,8 +122,76 @@ def test_random_invert_invalid_prob(): assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e) +def test_random_invert_one_channel(): + """ + Feature: RandomInvert + Description: test with one channel images + Expectation: raise errors as expected + """ + logger.info("test_random_invert_one_channel") + + c_op = RandomInvert() + + try: + data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False) + data_set = data_set.map(operations=[Decode(), Resize((224, 224)), lambda img: np.array(img[:, :, 0])], + input_columns=["image"]) + + data_set = data_set.map(operations=c_op, input_columns="image") + + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "image shape is incorrect, expected num of channels is 3." in str(e) + + +def test_random_invert_four_dim(): + """ + Feature: RandomInvert + Description: test with four dimension images + Expectation: raise errors as expected + """ + logger.info("test_random_invert_four_dim") + + c_op = RandomInvert() + + try: + data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False) + data_set = data_set.map(operations=[Decode(), Resize((224, 224)), lambda img: np.array(img[2, 200, 10, 32])], + input_columns=["image"]) + + data_set = data_set.map(operations=c_op, input_columns="image") + + except ValueError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "image shape is not " in str(e) + + +def test_random_invert_invalid_input(): + """ + Feature: RandomInvert + Description: test with images in uint32 type + Expectation: raise errors as expected + """ + logger.info("test_random_invert_invalid_input") + + c_op = RandomInvert() + + try: + data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False) + data_set = data_set.map(operations=[Decode(), Resize((224, 224)), + lambda img: np.array(img[2, 32, 3], dtype=uint32)], input_columns=["image"]) + data_set = data_set.map(operations=c_op, input_columns="image") + + except ValueError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Cannot convert from OpenCV type, unknown CV type" in str(e) + + if __name__ == "__main__": test_random_invert_pipeline(plot=True) test_random_invert_eager() test_random_invert_comp(plot=True) test_random_invert_invalid_prob() + test_random_invert_one_channel() + test_random_invert_four_dim() + test_random_invert_invalid_input()