!30339 [BUG][MD][FUNC]RandomInvert

Merge pull request !30339 from yangwm/randominvert
This commit is contained in:
i-robot 2022-03-16 06:39:53 +00:00 committed by Gitee
commit 72ea797e4d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 84 additions and 0 deletions

View File

@ -13,14 +13,30 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_invert_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
const float RandomInvertOp::kDefProbability = 0.5;
Status RandomInvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
// check input
if (input->Rank() != DEFAULT_IMAGE_RANK) {
RETURN_STATUS_UNEXPECTED("RandomInvert: image shape is not <H,W,C>, got rank: " + std::to_string(input->Rank()));
}
if (input->shape()[CHANNEL_INDEX] != DEFAULT_IMAGE_CHANNELS) {
RETURN_STATUS_UNEXPECTED(
"RandomInvert: image shape is incorrect, expected num of channels is 3, "
"but got:" +
std::to_string(input->shape()[CHANNEL_INDEX]));
}
CHECK_FAIL_RETURN_UNEXPECTED(input->type().AsCVType() != kCVInvalidType,
"RandomInvert: Cannot convert from OpenCV type, unknown CV type. Currently "
"supported data type: [int8, uint8, int16, uint16, int32, float16, float32, float64].");
if (distribution_(rnd_)) {
return InvertOp::Compute(input, output);
}

View File

@ -122,8 +122,76 @@ def test_random_invert_invalid_prob():
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
def test_random_invert_one_channel():
"""
Feature: RandomInvert
Description: test with one channel images
Expectation: raise errors as expected
"""
logger.info("test_random_invert_one_channel")
c_op = RandomInvert()
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[Decode(), Resize((224, 224)), lambda img: np.array(img[:, :, 0])],
input_columns=["image"])
data_set = data_set.map(operations=c_op, input_columns="image")
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "image shape is incorrect, expected num of channels is 3." in str(e)
def test_random_invert_four_dim():
"""
Feature: RandomInvert
Description: test with four dimension images
Expectation: raise errors as expected
"""
logger.info("test_random_invert_four_dim")
c_op = RandomInvert()
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[Decode(), Resize((224, 224)), lambda img: np.array(img[2, 200, 10, 32])],
input_columns=["image"])
data_set = data_set.map(operations=c_op, input_columns="image")
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "image shape is not <H,W,C>" in str(e)
def test_random_invert_invalid_input():
"""
Feature: RandomInvert
Description: test with images in uint32 type
Expectation: raise errors as expected
"""
logger.info("test_random_invert_invalid_input")
c_op = RandomInvert()
try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[Decode(), Resize((224, 224)),
lambda img: np.array(img[2, 32, 3], dtype=uint32)], input_columns=["image"])
data_set = data_set.map(operations=c_op, input_columns="image")
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Cannot convert from OpenCV type, unknown CV type" in str(e)
if __name__ == "__main__":
test_random_invert_pipeline(plot=True)
test_random_invert_eager()
test_random_invert_comp(plot=True)
test_random_invert_invalid_prob()
test_random_invert_one_channel()
test_random_invert_four_dim()
test_random_invert_invalid_input()