forked from mindspore-Ecosystem/mindspore
!30339 [BUG][MD][FUNC]RandomInvert
Merge pull request !30339 from yangwm/randominvert
This commit is contained in:
commit
72ea797e4d
|
@ -13,14 +13,30 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/random_invert_op.h"
|
||||
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const float RandomInvertOp::kDefProbability = 0.5;
|
||||
|
||||
Status RandomInvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
// check input
|
||||
if (input->Rank() != DEFAULT_IMAGE_RANK) {
|
||||
RETURN_STATUS_UNEXPECTED("RandomInvert: image shape is not <H,W,C>, got rank: " + std::to_string(input->Rank()));
|
||||
}
|
||||
if (input->shape()[CHANNEL_INDEX] != DEFAULT_IMAGE_CHANNELS) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"RandomInvert: image shape is incorrect, expected num of channels is 3, "
|
||||
"but got:" +
|
||||
std::to_string(input->shape()[CHANNEL_INDEX]));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type().AsCVType() != kCVInvalidType,
|
||||
"RandomInvert: Cannot convert from OpenCV type, unknown CV type. Currently "
|
||||
"supported data type: [int8, uint8, int16, uint16, int32, float16, float32, float64].");
|
||||
if (distribution_(rnd_)) {
|
||||
return InvertOp::Compute(input, output);
|
||||
}
|
||||
|
|
|
@ -122,8 +122,76 @@ def test_random_invert_invalid_prob():
|
|||
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
|
||||
|
||||
|
||||
def test_random_invert_one_channel():
|
||||
"""
|
||||
Feature: RandomInvert
|
||||
Description: test with one channel images
|
||||
Expectation: raise errors as expected
|
||||
"""
|
||||
logger.info("test_random_invert_one_channel")
|
||||
|
||||
c_op = RandomInvert()
|
||||
|
||||
try:
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data_set = data_set.map(operations=[Decode(), Resize((224, 224)), lambda img: np.array(img[:, :, 0])],
|
||||
input_columns=["image"])
|
||||
|
||||
data_set = data_set.map(operations=c_op, input_columns="image")
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "image shape is incorrect, expected num of channels is 3." in str(e)
|
||||
|
||||
|
||||
def test_random_invert_four_dim():
|
||||
"""
|
||||
Feature: RandomInvert
|
||||
Description: test with four dimension images
|
||||
Expectation: raise errors as expected
|
||||
"""
|
||||
logger.info("test_random_invert_four_dim")
|
||||
|
||||
c_op = RandomInvert()
|
||||
|
||||
try:
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data_set = data_set.map(operations=[Decode(), Resize((224, 224)), lambda img: np.array(img[2, 200, 10, 32])],
|
||||
input_columns=["image"])
|
||||
|
||||
data_set = data_set.map(operations=c_op, input_columns="image")
|
||||
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "image shape is not <H,W,C>" in str(e)
|
||||
|
||||
|
||||
def test_random_invert_invalid_input():
|
||||
"""
|
||||
Feature: RandomInvert
|
||||
Description: test with images in uint32 type
|
||||
Expectation: raise errors as expected
|
||||
"""
|
||||
logger.info("test_random_invert_invalid_input")
|
||||
|
||||
c_op = RandomInvert()
|
||||
|
||||
try:
|
||||
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
|
||||
data_set = data_set.map(operations=[Decode(), Resize((224, 224)),
|
||||
lambda img: np.array(img[2, 32, 3], dtype=uint32)], input_columns=["image"])
|
||||
data_set = data_set.map(operations=c_op, input_columns="image")
|
||||
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Cannot convert from OpenCV type, unknown CV type" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_invert_pipeline(plot=True)
|
||||
test_random_invert_eager()
|
||||
test_random_invert_comp(plot=True)
|
||||
test_random_invert_invalid_prob()
|
||||
test_random_invert_one_channel()
|
||||
test_random_invert_four_dim()
|
||||
test_random_invert_invalid_input()
|
||||
|
|
Loading…
Reference in New Issue