forked from mindspore-Ecosystem/mindspore
!1342 Bug fix on issue Core dump on GPU when train with lenet with AU
Merge pull request !1342 from Tinazhang/cc
This commit is contained in:
commit
39b9aedf68
|
@ -55,11 +55,11 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
|||
|
||||
// apply C++ ops (note: python OPs are not accepted)
|
||||
if (count == 1) {
|
||||
(**tensor_op).Compute(input, output);
|
||||
RETURN_IF_NOT_OK((**tensor_op).Compute(input, output));
|
||||
} else if (count % 2 == 0) {
|
||||
(**tensor_op).Compute(*output, even_out_ptr);
|
||||
RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr));
|
||||
} else {
|
||||
(**tensor_op).Compute(even_out, output);
|
||||
RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output));
|
||||
}
|
||||
count++;
|
||||
}
|
||||
|
|
|
@ -226,6 +226,27 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "num_ops" in str(e)
|
||||
|
||||
def test_cpp_uniform_augment_random_crop_ut():
|
||||
batch_size=2
|
||||
cifar10_dir = "../data/dataset/testCifar10Data"
|
||||
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
|
||||
|
||||
transforms_ua = [
|
||||
C.RandomCrop(size=[224, 224]),
|
||||
C.RandomHorizontalFlip()
|
||||
]
|
||||
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1)
|
||||
ds1 = ds1.map(input_columns="image", operations=uni_aug)
|
||||
|
||||
# apply DatasetOps
|
||||
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
|
||||
num_batches = 0
|
||||
try:
|
||||
for data in ds1.create_dict_iterator():
|
||||
num_batches += 1
|
||||
except BaseException as e:
|
||||
assert "Crop size" in str(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_uniform_augment(num_ops=1)
|
||||
|
@ -233,3 +254,4 @@ if __name__ == "__main__":
|
|||
test_cpp_uniform_augment_exception_pyops(num_ops=1)
|
||||
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
|
||||
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
|
||||
test_cpp_uniform_augment_random_crop_ut()
|
||||
|
|
Loading…
Reference in New Issue