fix ctpn BNTrainingUpdateGrad inconsistent data type

This commit is contained in:
zhouneng2 2021-08-07 16:17:46 +08:00
parent c84e09cf45
commit 2c999fcae5
2 changed files with 4 additions and 4 deletions

View File

@ -114,13 +114,13 @@ pretraining_dataset_file: ""
finetune_dataset_file: ""
# pretrain lr
pre_base_lr: 0.0009
pre_base_lr: 0.009
pre_warmup_step: 30000
pre_warmup_ratio: 1/3
pre_total_epoch: 100
# finetune lr
fine_base_lr: 0.0005
fine_base_lr: 0.005
fine_warmup_step: 300
fine_warmup_ratio: 1/3
fine_total_epoch: 50

View File

@ -92,8 +92,8 @@ class CTPN(nn.Cell):
self.num_step = config.num_step
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.vgg16_feature_extractor = VGG16FeatureExtraction()
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
self.vgg16_feature_extractor = VGG16FeatureExtraction().to_float(mstype.float16)
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same').to_float(mstype.float16)
self.rnn = BiLSTM(self.config, batch_size=self.batch_size).to_float(mstype.float16)
self.reshape = P.Reshape()
self.transpose = P.Transpose()