forked from mindspore-Ecosystem/mindspore
!13689 opt warpctc performance
From: @gengdongjie Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
61b445ff41
|
@ -67,12 +67,6 @@ def transpose_hwc2whc(image):
|
|||
return image
|
||||
|
||||
|
||||
def transpose_hwc2chw(image):
|
||||
"""transpose image from HWC to CHW"""
|
||||
image = np.transpose(image, (2, 0, 1))
|
||||
return image
|
||||
|
||||
|
||||
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
|
||||
"""
|
||||
create train or evaluation dataset for warpctc
|
||||
|
@ -93,14 +87,20 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
|
|||
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
|
||||
c.TypeCast(mstype.float16)
|
||||
]
|
||||
image_trans_gpu = [
|
||||
vc.Rescale(1.0 / 255.0, 0.0),
|
||||
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
|
||||
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
|
||||
vc.HWC2CHW()
|
||||
]
|
||||
label_trans = [
|
||||
c.TypeCast(mstype.int32)
|
||||
]
|
||||
data_set = data_set.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
|
||||
if device_target == 'Ascend':
|
||||
data_set = data_set.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8)
|
||||
else:
|
||||
data_set = data_set.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=image_trans_gpu, input_columns=["image"], num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)
|
||||
|
||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||
|
|
|
@ -123,7 +123,6 @@ class StackedRNNForGPU(nn.Cell):
|
|||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.cast(x, mstype.float32)
|
||||
x = self.transpose(x, (3, 0, 2, 1))
|
||||
x = self.reshape(x, (-1, self.batch_size, self.input_size))
|
||||
output, _ = self.lstm(x, (self.h, self.c))
|
||||
|
|
Loading…
Reference in New Issue