!8574 recitify warpctc script

From: @gengdongjie
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2020-11-17 09:31:02 +08:00 committed by Gitee
commit 8658b2f3ed
4 changed files with 24 additions and 29 deletions

View File

@ -61,12 +61,18 @@ class _CaptchaDataset:
return image, label
def transpose_hwc2whc(image):
"""transpose image from HWC to WHC"""
image = np.transpose(image, (1, 0, 2))
return image
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
"""
create train or evaluation dataset for warpctc
Args:
dataset_path(int): dataset path
dataset_path(str): dataset path
batch_size(int): batch size of generated dataset, default is 1
num_shards(int): number of devices
shard_id(int): rank id
@ -79,12 +85,13 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
vc.Rescale(1.0 / 255.0, 0.0),
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
vc.HWC2CHW()
c.TypeCast(mstype.float16)
]
label_trans = [
c.TypeCast(mstype.int32)
]
ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8)
ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)

View File

@ -50,7 +50,7 @@ class WarpCTCAccuracy(nn.Metric):
def eval(self):
if self._total_num == 0:
raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num
def _is_eq(self, pred_lbl, target):

View File

@ -33,13 +33,13 @@ class StackedRNN(nn.Cell):
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, input_size, batch_size=64, hidden_size=512):
def __init__(self, input_size, batch_size=64, hidden_size=512, num_class=11):
super(StackedRNN, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.num_classes = 11
self.reshape = P.Reshape()
self.cast = P.Cast()
self.batch_size = batch_size
self.hidden_size = hidden_size
self.num_class = num_class
k = (1 / hidden_size) ** 0.5
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
@ -58,32 +58,23 @@ class StackedRNN(nn.Cell):
self.c1 = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float16))
self.c2 = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float16))
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
self.fc_weight = Tensor(np.random.random((hidden_size, num_class)).astype(np.float16))
self.fc_bias = Tensor(np.random.random(self.num_class).astype(np.float16))
self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
bias_init=Tensor(self.fc_bias))
self.fc.to_float(mstype.float32)
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.squeeze = P.Squeeze(axis=0)
self.matmul = nn.MatMul()
def construct(self, x):
x = self.cast(x, mstype.float16)
x = self.transpose(x, (3, 0, 2, 1))
x = self.transpose(x, (1, 0, 2, 3))
x = self.reshape(x, (-1, self.batch_size, self.input_size))
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y2, _, _, _, _, _, _, _ = self.rnn2(y1, self.w2, self.b2, None, self.h2, self.c2)
output = ()
for i in range(F.shape(x)[0]):
y2_after_fc = self.fc(self.squeeze(y2[i:i + 1:1]))
y2_after_fc = self.expand_dims(y2_after_fc, 0)
output += (y2_after_fc,)
output = self.concat(output)
# [time_step, bs, hidden_size] * [hidden_size, num_class] + [num_class]
output = self.matmul(y2, self.fc_weight) + self.fc_bias
return output

View File

@ -51,14 +51,11 @@ if args_opt.platform == 'Ascend':
if __name__ == '__main__':
lr_scale = 1
if args_opt.run_distribute:
init()
if args_opt.platform == 'Ascend':
init()
lr_scale = 1
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
else:
init()
lr_scale = 1
device_num = get_group_size()
rank = get_rank()
context.reset_auto_parallel_context()