forked from mindspore-Ecosystem/mindspore
Added dimension check to required Offload ops.
This commit is contained in:
parent
e4438f3028
commit
fac68dbb46
|
@ -22,6 +22,7 @@ from mindspore.common.tensor import Tensor
|
|||
import mindspore.nn as nn
|
||||
import mindspore.ops.composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
||||
|
||||
def check_concat_zip_dataset(dataset):
|
||||
|
@ -53,6 +54,17 @@ def apply_offload_iterators(data, offload_model):
|
|||
return data
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_input_dims(x_shape, required_dim, offload_op_name):
|
||||
"""
|
||||
Check if input has the required number of dimensions for the operation.
|
||||
"""
|
||||
input_dim = len(x_shape)
|
||||
if input_dim is not required_dim:
|
||||
raise ValueError("For %s offload operation, the dimension of input should be %d, but got %d." %
|
||||
(offload_op_name, required_dim, input_dim))
|
||||
|
||||
|
||||
class ApplyPreTransform(nn.Cell):
|
||||
"""
|
||||
Concatenates offload model with network.
|
||||
|
@ -99,7 +111,9 @@ class RandomHorizontalFlip(nn.Cell):
|
|||
def construct(self, x):
|
||||
|
||||
x = self.cast(x, mstype.float32)
|
||||
bs, h, w, c = self.shape(x)
|
||||
x_shape = self.shape(x)
|
||||
check_input_dims(x_shape, 4, 'RandomHorizontalFlip')
|
||||
bs, h, w, c = x_shape
|
||||
|
||||
flip_rand_factor = self.uniformReal((bs, 1))
|
||||
flip_rand_factor = self.cast((self.prob > flip_rand_factor), mstype.float32)
|
||||
|
@ -130,7 +144,9 @@ class RandomVerticalFlip(nn.Cell):
|
|||
def construct(self, x):
|
||||
|
||||
x = self.cast(x, mstype.float32)
|
||||
bs, h, w, c = self.shape(x)
|
||||
x_shape = self.shape(x)
|
||||
check_input_dims(x_shape, 4, 'RandomVerticalFlip')
|
||||
bs, h, w, c = x_shape
|
||||
|
||||
flip_rand_factor = self.uniformReal((bs, 1))
|
||||
flip_rand_factor = self.cast((self.prob > flip_rand_factor), mstype.float32)
|
||||
|
@ -174,7 +190,9 @@ class RandomColorAdjust(nn.Cell):
|
|||
def construct(self, x):
|
||||
|
||||
x = self.cast(x, mstype.float32)
|
||||
bs, h, w, c = self.shape(x)
|
||||
x_shape = self.shape(x)
|
||||
check_input_dims(x_shape, 4, 'RandomColorAdjust')
|
||||
bs, h, w, c = x_shape
|
||||
|
||||
br_rand_factor = self.br_min + (self.br_max - self.br_min)*self.uniformReal((bs, 1))
|
||||
br_rand_factor = self.reshape(C.repeat_elements(br_rand_factor, rep=(h*w*c)), (bs, h, w, c))
|
||||
|
@ -226,7 +244,9 @@ class RandomSharpness(nn.Cell):
|
|||
def construct(self, x):
|
||||
|
||||
x = self.cast(x, mstype.float32)
|
||||
bs, h, w, c = self.shape(x)
|
||||
x_shape = self.shape(x)
|
||||
check_input_dims(x_shape, 4, 'RandomSharpness')
|
||||
bs, h, w, c = x_shape
|
||||
|
||||
degree_rand_factor = self.degree_min + (self.degree_max - self.degree_min)*self.uniformReal((bs, 1))
|
||||
degree_rand_factor = self.reshape(C.repeat_elements(degree_rand_factor, rep=(h*w*c)), (bs, h, w, c))
|
||||
|
@ -268,8 +288,11 @@ class HwcToChw(nn.Cell):
|
|||
def __init__(self):
|
||||
super(HwcToChw, self).__init__()
|
||||
self.trans = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape(x)
|
||||
check_input_dims(x_shape, 4, 'HwcToChw')
|
||||
return self.trans(x, (0, 3, 1, 2))
|
||||
|
||||
|
||||
|
|
|
@ -223,6 +223,23 @@ def test_offload_not_end_of_pipeline():
|
|||
np.testing.assert_(data_iterator.offload_model is None)
|
||||
|
||||
|
||||
def test_offload_dim_check():
|
||||
"""
|
||||
Feature: test input has the required number of dimensions for offload operation.
|
||||
Description: Input is image dataset.
|
||||
Expectation: Should raise ValueError.
|
||||
"""
|
||||
# Dataset with offload activated.
|
||||
dataset = ds.ImageFolderDataset(DATA_DIR)
|
||||
dataset = dataset.map(operations=[C.Decode()], input_columns="image")
|
||||
dataset = dataset.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
|
||||
|
||||
error_msg = "For HwcToChw offload operation, the dimension of input should be 4, but got 3."
|
||||
with pytest.raises(ValueError, match=error_msg):
|
||||
for (_, _) in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_offload()
|
||||
test_auto_offload()
|
||||
|
@ -232,3 +249,4 @@ if __name__ == "__main__":
|
|||
test_offload_rescale_op()
|
||||
test_offload_different_column_end_of_pipeline()
|
||||
test_offload_not_end_of_pipeline()
|
||||
test_offload_dim_check()
|
||||
|
|
Loading…
Reference in New Issue